diff --git a/README.md b/README.md index 95286e3..22fe43c 100644 --- a/README.md +++ b/README.md @@ -258,13 +258,14 @@ trainer = GraphTrainer(df) # extract features features = trainer.prepare_and_train() +from xpotato.dataset.utils import save_dataframe from sklearn.model_selection import train_test_split train, val = train_test_split(df, test_size=0.2, random_state=1234) # save train and validation; this is important for the frontend to work -train.to_pickle("train_dataset") -val.to_pickle("val_dataset") +save_dataframe(train, 'train.tsv') +save_dataframe(val, 'val.tsv') import json @@ -287,18 +288,18 @@ with open("graphs.pickle", "wb") as f: If the DataFrame is ready with the parsed graphs, the UI can be started to inspect the extracted rules and modify them. The frontend is a Streamlit app; the simplest way of starting it is (the training and the validation datasets must be provided): ``` -streamlit run frontend/app.py -- -t notebooks/train_dataset -v notebooks/val_dataset -g ud +streamlit run frontend/app.py -- -t notebooks/train.tsv -v notebooks/val.tsv -g ud ``` It can also be started with the extracted features: ``` -streamlit run frontend/app.py -- -t notebooks/train_dataset -v notebooks/val_dataset -g ud -sr notebooks/features.json +streamlit run frontend/app.py -- -t notebooks/train.tsv -v notebooks/val.tsv -g ud -sr notebooks/features.json ``` If you have already used the UI and extracted the features manually, you can load them by running: ``` -streamlit run frontend/app.py -- -t notebooks/train_dataset -v notebooks/val_dataset -g ud -sr notebooks/features.json -hr notebooks/manual_features.json +streamlit run frontend/app.py -- -t notebooks/train.tsv -v notebooks/val.tsv -g ud -sr notebooks/features.json -hr notebooks/manual_features.json ``` ### Advanced mode @@ -331,7 +332,7 @@ sentences = [("Governments and industries in nations around the world are pourin Then, the frontend can be started: ``` -streamlit run frontend/app.py -- -t notebooks/unsupervised_dataset -g ud -m advanced +streamlit run frontend/app.py -- -t notebooks/unsupervised_dataset.tsv -g ud -m advanced ``` Once the frontend starts up and you define the labels, you are faced with the annotation interface. You can search elements by clicking on the appropriate column name and applying the desired filter. You can annotate instances by checking the checkbox at the beginning of the line. You can check multiple checkboxes at a time. Once you've selected the utterances you want to annotate, click on the _Annotate_ button. The annotated samples will appear in the lower table. You can clear the annotation of certain elements by selecting them in the second table and clicking _Clear annotation_. @@ -345,7 +346,7 @@ Once you have some annotated data, you can train rules by clicking the _Train!_ If you have the features ready and you want to evaluate them on a test set, you can run: ```python -python scripts/evaluate.py -t ud -f notebooks/features.json -d notebooks/val_dataset +python scripts/evaluate.py -t ud -f notebooks/features.json -d notebooks/val.tsv ``` The result will be a _csv_ file with the labels and the matched rules.
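Note on the new TSV format: throughout this change, datasets move from opaque pickles to plain TSV files. `save_dataframe` serializes each entry of the `graph` column to a Penman string before writing, and the loader parses those strings back into graph objects. A minimal sketch of that round-trip, assuming the `graph_to_pn` and `PotatoGraph` helpers used by `frontend/utils.py` later in this diff (the two wrapper function names here are illustrative, not part of the xpotato API):

```python
import pandas as pd

from tuw_nlp.graph.utils import graph_to_pn
from xpotato.graph_extractor.graph import PotatoGraph


def save_dataframe_tsv(df: pd.DataFrame, path: str) -> None:
    # Serialize each graph to a Penman string so the whole
    # DataFrame can be written as plain, diff-friendly TSV.
    out = df.copy()  # avoid mutating the caller's DataFrame
    out["graph"] = [graph_to_pn(g) for g in out["graph"]]
    out.to_csv(path, sep="\t", index=False)


def load_dataframe_tsv(path: str) -> pd.DataFrame:
    # Read the TSV back and rebuild graph objects from the strings.
    df = pd.read_csv(path, sep="\t")
    df["graph"] = [PotatoGraph(graph_str=s).graph for s in df["graph"]]
    return df
```

Compared to `to_pickle`, this keeps the train/validation files human-readable and version-control friendly; the trade-off is that graphs must be re-parsed from Penman on every load.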
diff --git a/features/crowdtruth/README.md b/features/crowdtruth/README.md index 56f63ea..6275f4d 100644 --- a/features/crowdtruth/README.md +++ b/features/crowdtruth/README.md @@ -15,11 +15,11 @@ Prebuilt rule-systems for both the _cause_ and the _treat_ label are also availa Then the frontend of POTATO can be started from the __frontend__ directory: ```bash -streamlit run app.py -- -t ../features/crowdtruth/crowdtruth_train_dataset_cause_ud.pickle -v ../features/crowdtruth/crowdtruth_dev_dataset_cause_ud.pickle -hr ../features/crowdtruth/crowd_cause_features_ud.json +streamlit run app.py -- -t ../features/crowdtruth/crowdtruth_train_dataset_cause_ud.tsv -v ../features/crowdtruth/crowdtruth_dev_dataset_cause_ud.tsv -hr ../features/crowdtruth/crowd_cause_features_ud.json ``` Once you are done building the rule-system, you can evaluate it on the test data; to do so, run _evaluate.py_ from the _scripts_ directory. ```bash -python evaluate.py -t ud -f ../features/crowdtruth/crowd_cause_features_ud.json -d ../features/crowdtruth/crowdtruth_train_dataset_cause_ud.pickle +python evaluate.py -t ud -f ../features/crowdtruth/crowd_cause_features_ud.json -d ../features/crowdtruth/crowdtruth_test_dataset_cause_ud.tsv ``` \ No newline at end of file diff --git a/features/crowdtruth/crowdtruth.ipynb b/features/crowdtruth/crowdtruth.ipynb index 375cbd4..e3faed1 100644 --- a/features/crowdtruth/crowdtruth.ipynb +++ b/features/crowdtruth/crowdtruth.ipynb @@ -1,5 +1,19 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "77655d7b", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc -q -O \"ground_truth_cause.csv\" \"https://raw.githubusercontent.com/CrowdTruth/Medical-Relation-Extraction/master/ground_truth_cause.csv\"\n", + "!wget -nc -q -O \"ground_truth_treat.csv\" \"https://raw.githubusercontent.com/CrowdTruth/Medical-Relation-Extraction/master/ground_truth_treat.csv\"\n", + "!wget -nc -q -O \"ground_truth_cause.xlsx\" \"https://github.com/CrowdTruth/Medical-Relation-Extraction/blob/master/train_dev_test/ground_truth_cause.xlsx?raw=true\"\n", + "!wget -nc -q -O \"ground_truth_treat.xlsx\" \"https://github.com/CrowdTruth/Medical-Relation-Extraction/blob/master/train_dev_test/ground_truth_treat.xlsx?raw=true\"\n", + "!wget -nc -q -O \"food_disease_dataset.csv\" \"https://raw.githubusercontent.com/gjorgjinac/food-disease-dataset/main/food_disease_dataset.csv\"" + ] + }, { "cell_type": "code", "execution_count": 16, @@ -324,16 +338,16 @@ "metadata": {}, "outputs": [], "source": [ + "\n", + "from xpotato.dataset.utils import save_dataframe\n", + "\n", "train_df = train_dataset.to_dataframe()\n", "dev_df = dev_dataset.to_dataframe()\n", "test_df = test_dataset.to_dataframe()\n", "\n", - "#train_df.to_pickle(\"crowdtruth_train_dataset_treat_fourlang.pickle\")\n", - "#dev_df.to_pickle(\"crowdtruth_dev_dataset_treat_fourlang.pickle\")\n", - "#test_df.to_pickle(\"crowdtruth_test_dataset_treat_fourlang.pickle\")\n", - "train_df.to_pickle(\"crowdtruth_train_dataset_cause_fourlang.pickle\")\n", - "dev_df.to_pickle(\"crowdtruth_dev_dataset_cause_fourlang.pickle\")\n", - "test_df.to_pickle(\"crowdtruth_test_dataset_cause_fourlang.pickle\")" + "save_dataframe(train_df, \"crowdtruth_train_dataset_cause_fourlang.tsv\")\n", + "save_dataframe(dev_df, \"crowdtruth_dev_dataset_cause_fourlang.tsv\")\n", + "save_dataframe(test_df, \"crowdtruth_test_dataset_cause_fourlang.tsv\")" ] }, { diff --git a/features/crowdtruth/data.sh b/features/crowdtruth/data.sh index 1e80806..1a1b04b 100644 ---
a/features/crowdtruth/data.sh +++ b/features/crowdtruth/data.sh @@ -1,12 +1,12 @@ -wget https://owncloud.tuwien.ac.at/index.php/s/z3IMX2fUNM7Kw6i/download -O crowdtruth_dev_dataset_cause_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/C4MOznjvxpcU5Ik/download -O crowdtruth_dev_dataset_cause_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/39s2AsFYTL3Keni/download -O crowdtruth_dev_dataset_treat_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/RqC1SzWhRXoKOnn/download -O crowdtruth_dev_dataset_treat_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/WpxGeblkiEhkIib/download -O crowdtruth_test_dataset_cause_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/wro8yTxXYK6WpF8/download -O crowdtruth_test_dataset_cause_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/xLz0fOxjb8ORBlR/download -O crowdtruth_test_dataset_treat_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/RaCcWl0xVdVpPQZ/download -O crowdtruth_test_dataset_treat_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/i7BuiCMvYWcZlI1/download -O crowdtruth_train_dataset_cause_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/NAHY0g1XqYM28LQ/download -O crowdtruth_train_dataset_cause_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/OPzP4kgD4PVwZOA/download -O crowdtruth_train_dataset_treat_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/sL3s3uaUgnLdKsy/download -O crowdtruth_train_dataset_treat_ud.pickle +wget https://owncloud.tuwien.ac.at/index.php/s/aHX8ByPg8nN3W5v/download -O crowdtruth_dev_dataset_cause_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/1P1OppoaeFPk4iI/download -O crowdtruth_dev_dataset_cause_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/imAYGbrNVtTHCRs/download -O crowdtruth_dev_dataset_treat_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/oOOZhWVjC40xxQm/download -O crowdtruth_dev_dataset_treat_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/C2SQeWPqDdQrtXQ/download -O crowdtruth_test_dataset_cause_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/3PGrMU6SINTSbfl/download -O crowdtruth_test_dataset_cause_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/hDyM5x4XCcqANt3/download -O crowdtruth_test_dataset_treat_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/SGv5zZm5UyulXT1/download -O crowdtruth_test_dataset_treat_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/KcpBVwigbB19H56/download -O crowdtruth_train_dataset_cause_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/tjLqzSUl0zU32zu/download -O crowdtruth_train_dataset_cause_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/0cDVR9nz0I4QWvp/download -O crowdtruth_train_dataset_treat_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/PTOhXBqxrLmrAzW/download -O crowdtruth_train_dataset_treat_ud.tsv diff --git a/features/food/README.md b/features/food/README.md index 90484de..984ab0a 100644 --- a/features/food/README.md +++ b/features/food/README.md @@ -15,11 +15,11 @@ Prebuilt rule-systems for both the _cause_ and the _treat_ label are also availa Then the frontend of POTATO can be started from the __frontend__ directory: ```bash -streamlit run app.py -- -t ../features/food/food_train_dataset_cause_ud.pickle -v ../features/food/food_dev_dataset_cause_ud.pickle -hr ../features/crowdtruth/food_cause_features_ud.json +streamlit run app.py -- -t ../features/food/food_train_dataset_cause_ud.tsv -v ../features/food/food_dev_dataset_cause_ud.tsv 
-hr ../features/food/food_cause_features_ud.json ``` Once you are done building the rule-system, you can evaluate it on the test data; to do so, run _evaluate.py_ from the _scripts_ directory. ```bash -python evaluate.py -t ud -f ../features/food/food_cause_features_ud.json -d ../features/crowdtruth/food_train_dataset_cause_ud.pickle +python evaluate.py -t ud -f ../features/food/food_cause_features_ud.json -d ../features/food/food_test_dataset_cause_ud.tsv ``` \ No newline at end of file diff --git a/features/food/data.sh b/features/food/data.sh index aa21fdd..34a183b 100644 --- a/features/food/data.sh +++ b/features/food/data.sh @@ -1,8 +1,8 @@ -wget https://owncloud.tuwien.ac.at/index.php/s/G8pbpWQq6bqYbXp/download -O food_dev_dataset_cause_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/zNlkmijP6T0bRT5/download -O food_dev_dataset_cause_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/lJIRnQBkhyn8bQs/download -O food_dev_dataset_treat_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/Nj9vpcBs2C4aFMW/download -O food_dev_dataset_treat_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/WFoTXbRrtn1QDqT/download -O food_test_dataset_cause_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/dEvaQhhCQ39e2hv/download -O food_test_dataset_cause_ud.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/A9U3iz5SzGwmdW6/download -O food_test_dataset_treat_fourlang.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/d4Q09GVI89XwKuD/download -O food_test_dataset_treat_ud.pickle +wget https://owncloud.tuwien.ac.at/index.php/s/eQHmVCULV3sYVKF/download -O food_dev_dataset_cause_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/Jem0O20atHYJYkf/download -O food_dev_dataset_cause_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/62v47pY8KwBwlJj/download -O food_dev_dataset_treat_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/3KSW4JUJRcUp5zA/download -O food_dev_dataset_treat_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/EC8qjI6Jo1BTaJ4/download -O food_test_dataset_cause_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/LWoP5x2DD0QzM2p/download -O food_test_dataset_cause_ud.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/b8DILcmjJhH7IgP/download -O food_test_dataset_treat_fourlang.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/CDmcKXJlcRv8Wcv/download -O food_test_dataset_treat_ud.tsv diff --git a/features/food/food.ipnyb b/features/food/food.ipynb similarity index 96% rename from features/food/food.ipnyb rename to features/food/food.ipynb index 96d020a..587788f 100644 --- a/features/food/food.ipnyb +++ b/features/food/food.ipynb @@ -209,8 +209,10 @@ "metadata": {}, "outputs": [], "source": [ - "train_df.to_pickle(\"food_train_dataset_treat_ud.pickle\")\n", - "dev_df.to_pickle(\"food_dev_dataset_treat_ud.pickle\")" + "from xpotato.dataset.utils import save_dataframe\n", + "\n", + "save_dataframe(train_df, 'food_train_dataset_treat_ud.tsv')\n", + "save_dataframe(dev_df, 'food_dev_dataset_treat_ud.tsv')" ] }, { @@ -255,8 +257,8 @@ "metadata": {}, "outputs": [], "source": [ - "train_df.to_pickle(\"food_train_dataset_cause_fourlang.pickle\")\n", - "dev_df.to_pickle(\"food_dev_dataset_cause_fourang.pickle\")" + "save_dataframe(train_df, 'food_train_dataset_cause_fourlang.tsv')\n", + "save_dataframe(dev_df, 'food_dev_dataset_cause_fourlang.tsv')" ] }, { diff --git a/features/hasoc/README.md b/features/hasoc/README.md index 8931b91..78b3fed 100644 --- a/features/hasoc/README.md +++
b/features/hasoc/README.md @@ -15,18 +15,18 @@ Prebuilt rule-systems are available in this directory for the _2019, 2020, 2021_ Then the frontend of POTATO can be started from the __frontend__ directory: ```bash -streamlit run app.py -- -t ../features/hasoc/hasoc_2021_train_amr.pickle -v ../features/hasoc/hasoc_2021_val_amr.pickle -hr ../features/hasoc/2021_train_features_task1.json +streamlit run app.py -- -t ../features/hasoc/hasoc_2021_train_amr.tsv -v ../features/hasoc/hasoc_2021_val_amr.tsv -hr ../features/hasoc/2021_train_features_task1.json ``` If you want to reproduce our output, run _evaluate.py_ from the _scripts_ directory. ```bash -python evaluate.py -t amr -f ../features/hasoc/2021_train_features_task1.json -d ../features/hasoc/hasoc_2021_test_amr.pickle +python evaluate.py -t amr -f ../features/hasoc/2021_train_features_task1.json -d ../features/hasoc/hasoc_2021_test_amr.tsv ``` If you want to get the classification report, run the script with the __mode__ (-m) parameter: ```bash -python evaluate.py -t amr -f ../features/hasoc/2021_train_features_task1.json -d ../features/hasoc/hasoc_2021_test_amr.pickle -m report +python evaluate.py -t amr -f ../features/hasoc/2021_train_features_task1.json -d ../features/hasoc/hasoc_2021_test_amr.tsv -m report ``` ## Usage and examples on the HASOC data diff --git a/features/hasoc/data.sh b/features/hasoc/data.sh index e828a31..c78dc9b 100644 --- a/features/hasoc/data.sh +++ b/features/hasoc/data.sh @@ -1,9 +1,9 @@ -wget https://owncloud.tuwien.ac.at/index.php/s/VChBRMu2CghoVEB/download -O hasoc_2019_val_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/80ndwqwAnIqkTKt/download -O hasoc_2019_test_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/PtD2aqtuJtzUoH2/download -O hasoc_2019_train_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/gzlHeqNkp95ehLH/download -O hasoc_2020_val_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/RtiiwCjpyJ1pqdu/download -O hasoc_2020_test_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/mngqfVDaTsW7odk/download -O hasoc_2020_train_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/paqXOSj7bbMd5ZI/download -O hasoc_2021_val_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/oocwRTd0XRhgFYd/download -O hasoc_2021_test_amr.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/Khv85ErE6s0cSAc/download -O hasoc_2021_train_amr.pickle \ No newline at end of file +wget https://owncloud.tuwien.ac.at/index.php/s/sUHFGNdvphCUZsQ/download -O hasoc_2019_val_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/PsaHO8N02K9u8sp/download -O hasoc_2019_test_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/QLsQaME33zdT5Xw/download -O hasoc_2019_train_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/Um7BjFu5847yXmd/download -O hasoc_2020_val_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/47HQ9sKo5PmTCTH/download -O hasoc_2020_test_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/hQ56wvpRKxUzVi8/download -O hasoc_2020_train_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/2w8VNtqm7PXTgTX/download -O hasoc_2021_val_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/5Y1V67KMwNMmLC8/download -O hasoc_2021_test_amr.tsv +wget https://owncloud.tuwien.ac.at/index.php/s/rhTbyW1CbfQuWk0/download -O hasoc_2021_train_amr.tsv \ No newline at end of file diff --git a/features/semeval/README.md b/features/semeval/README.md index 0649a85..a55900f 100644 --- a/features/semeval/README.md +++ b/features/semeval/README.md @@ -13,5 +13,5 @@
bash data.sh Then the frontend of POTATO can be started from the __frontend__ directory: ```bash -streamlit run app.py -- -t ../features/semeval/semeval_train.pickle -v ../features/semeval/semeval_val.pickle +streamlit run app.py -- -t ../features/semeval/semeval_train.tsv -v ../features/semeval/semeval_val.tsv ``` \ No newline at end of file diff --git a/features/semeval/data.sh b/features/semeval/data.sh index 8c6a37f..cdb630a 100644 --- a/features/semeval/data.sh +++ b/features/semeval/data.sh @@ -1,4 +1,4 @@ -wget https://owncloud.tuwien.ac.at/index.php/s/6gHDG8XArRuyzDc/download -O semeval_train.pickle +wget https://owncloud.tuwien.ac.at/index.php/s/OgNbqmkUgmcmCTA/download -O semeval_train.tsv wget https://owncloud.tuwien.ac.at/index.php/s/2ESe3bVKiSjZ8jJ/download -O semeval_train.txt wget https://owncloud.tuwien.ac.at/index.php/s/Nx3p4BG9xx7FHVQ/download -O semeval_train_4lang_graphs.pickle -wget https://owncloud.tuwien.ac.at/index.php/s/iX8Fmfsyf6vml6t/download -O semeval_val.pickle +wget https://owncloud.tuwien.ac.at/index.php/s/OgNbqmkUgmcmCTA/download -O semeval_val.tsv diff --git a/frontend/app.py b/frontend/app.py index 4a4d24e..c2668a8 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -22,11 +22,11 @@ init_extractor, init_session_states, rank_and_suggest, - read_train, - read_val, + read_df, rerun, rule_chooser, save_ruleset, + read_ruleset, save_after_modify, save_dataframe, match_texts, @@ -62,8 +62,7 @@ def inference_mode(evaluator, hand_made_rules): st.session_state.download = st.sidebar.selectbox("", options=[False, True], key=2) if hand_made_rules: - with open(hand_made_rules) as f: - st.session_state.features = json.load(f) + read_ruleset(hand_made_rules) extractor = init_extractor(lang, graph_format) @@ -181,7 +180,7 @@ def inference_mode(evaluator, hand_made_rules): [";".join(feat[1]) for feat in features_merged], [feat[2] for feat in features_merged], ) - save_rules = hand_made_rules or "saved_features.json" + save_rules = hand_made_rules or "saved_features.tsv" save_ruleset(save_rules, st.session_state.features) rerun() @@ -226,8 +225,7 @@ def inference_mode(evaluator, hand_made_rules): def simple_mode(evaluator, data, val_data, graph_format, feature_path, hand_made_rules): if hand_made_rules: - with open(hand_made_rules) as f: - st.session_state.features = json.load(f) + read_ruleset(hand_made_rules) if "df" not in st.session_state: st.session_state.df = data.copy() @@ -634,10 +632,9 @@ def simple_mode(evaluator, data, val_data, graph_format, feature_path, hand_made def advanced_mode(evaluator, train_data, graph_format, feature_path, hand_made_rules): - data = read_train(train_data) + data = read_df(train_data) if hand_made_rules: - with open(hand_made_rules) as f: - st.session_state.features = json.load(f) + read_ruleset(hand_made_rules) if "df" not in st.session_state: st.session_state.df = data.copy() if "annotated" not in st.session_state.df: @@ -1216,9 +1213,9 @@ def main(args): init_session_states() evaluator = init_evaluator() if args.train_data: - data = read_train(args.train_data, args.label) + data = read_df(args.train_data, args.label) if args.val_data: - val_data = read_val(args.val_data, args.label) + val_data = read_df(args.val_data, args.label) graph_format = args.graph_format feature_path = args.suggested_rules hand_made_rules = args.hand_rules diff --git a/frontend/utils.py b/frontend/utils.py index 8b02038..19b9461 100644 --- a/frontend/utils.py +++ b/frontend/utils.py @@ -13,9 +13,11 @@ from streamlit.report_thread import 
REPORT_CONTEXT_ATTR_NAME from xpotato.dataset.utils import default_pn_to_graph +from xpotato.graph_extractor.graph import PotatoGraph from xpotato.graph_extractor.extract import FeatureEvaluator, GraphExtractor from xpotato.models.trainer import GraphTrainer from xpotato.dataset.utils import default_pn_to_graph +from xpotato.graph_extractor.rule import RuleSet, Rule from tuw_nlp.graph.utils import GraphFormulaMatcher, graph_to_pn from contextlib import contextmanager @@ -156,8 +158,32 @@ def to_dot(graph, marked_nodes=set(), integ=False): def save_ruleset(path, features): - with open(path, "w+") as f: - json.dump(features, f) + rule_set = RuleSet() + rule_set.from_dict(features) + + if path.endswith(".json"): + rule_set.to_json(path) + elif path.endswith(".tsv"): + rule_set.to_tsv(path) + else: + raise ValueError( + "Unknown file extension, currently only .json and .tsv are supported" + ) + + +def read_ruleset(path): + rule_set = RuleSet() + + if path.endswith(".json"): + rule_set.from_json(path) + elif path.endswith(".tsv"): + rule_set.from_tsv(path) + else: + raise ValueError( + "Unknown file extension, currently only .json and .tsv are supported" + ) + + st.session_state.features = rule_set.to_dict() def d_clean(string): @@ -217,7 +243,7 @@ def save_after_modify(hand_made_rules, classes=None): [feat[2] for feat in features_merged], ) - save_rules = hand_made_rules or "saved_features.json" + save_rules = hand_made_rules or "saved_features.tsv" save_ruleset(save_rules, st.session_state.features) st.session_state.rows_to_delete = [] rerun() @@ -229,23 +255,28 @@ def filter_label(df, label): @st.cache(allow_output_mutation=True) -def read_train(path, label=None): - df = pd.read_pickle(path) +def read_df(path, label=None, binary=False): + if binary: + df = pd.read_pickle(path) + else: + df = pd.read_csv(path, sep="\t") + graphs = [] + for graph in df["graph"]: + potato_graph = PotatoGraph(graph_str=graph) + graphs.append(potato_graph.graph) + df["graph"] = graphs if label is not None: filter_label(df, label) return df def save_dataframe(data, path): - data.to_pickle(path) - - -@st.cache(allow_output_mutation=True) -def read_val(path, label=None): - df = pd.read_pickle(path) - if label is not None: - filter_label(df, label) - return df + if ".pickle" in path: + data.to_pickle(path) + else: + graphs = data["graph"] + data["graph"] = [graph_to_pn(graph) for graph in graphs] + data.to_csv(path, sep="\t", index=False) def train_df(df, min_edge=0, rank=False): @@ -354,7 +385,7 @@ def show_ml_feature(classes, hand_made_rules): [";".join(feat[0]) for feat in st.session_state.features[classes]], [";".join(feat[1]) for feat in st.session_state.features[classes]], ) - save_rules = hand_made_rules or "saved_features.json" + save_rules = hand_made_rules or "saved_features.tsv" save_ruleset(save_rules, st.session_state.features) rerun() @@ -526,7 +557,7 @@ def add_rule_manually(classes, hand_made_rules): [";".join(feat[0]) for feat in st.session_state.features[classes]], [";".join(feat[1]) for feat in st.session_state.features[classes]], ) - save_rules = hand_made_rules or "saved_features.json" + save_rules = hand_made_rules or "saved_features.tsv" save_ruleset(save_rules, st.session_state.features) rerun() st.markdown( diff --git a/notebooks/hasoc_examples.ipynb b/notebooks/hasoc_examples.ipynb index 7b6d348..079323a 100644 --- a/notebooks/hasoc_examples.ipynb +++ b/notebooks/hasoc_examples.ipynb @@ -394,28 +394,31 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-12-01 17:19:46 INFO: 
Loading these models for language: en (English):\n", - "=========================\n", - "| Processor | Package |\n", - "-------------------------\n", - "| tokenize | ewt |\n", - "| pos | ewt |\n", - "| lemma | ewt |\n", - "| depparse | ewt |\n", - "| sentiment | sstplus |\n", - "| ner | ontonotes |\n", - "=========================\n", + "2022-02-14 13:47:31,501 : core (112) - INFO - Loading these models for language: en (English):\n", + "============================\n", + "| Processor | Package |\n", + "----------------------------\n", + "| tokenize | combined |\n", + "| pos | combined |\n", + "| lemma | combined |\n", + "| depparse | combined |\n", + "| sentiment | sstplus |\n", + "| constituency | wsj |\n", + "| ner | ontonotes |\n", + "============================\n", "\n", - "2021-12-01 17:19:46 INFO: Use device: cpu\n", - "2021-12-01 17:19:46 INFO: Loading: tokenize\n", - "2021-12-01 17:19:46 INFO: Loading: pos\n", - "2021-12-01 17:19:48 INFO: Loading: lemma\n", - "2021-12-01 17:19:48 INFO: Loading: depparse\n", - "2021-12-01 17:19:49 INFO: Loading: sentiment\n", - "2021-12-01 17:19:50 INFO: Loading: ner\n", - "2021-12-01 17:19:51 INFO: Done loading processors!\n", - "WARNING:root:creating new NLP cache in en_nlp_cache\n", - "100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:07<00:00, 2.00it/s]\n" + "2022-02-14 13:47:31,505 : core (123) - INFO - Use device: cpu\n", + "2022-02-14 13:47:31,506 : core (129) - INFO - Loading: tokenize\n", + "2022-02-14 13:47:31,513 : core (129) - INFO - Loading: pos\n", + "2022-02-14 13:47:31,689 : core (129) - INFO - Loading: lemma\n", + "2022-02-14 13:47:31,718 : core (129) - INFO - Loading: depparse\n", + "2022-02-14 13:47:32,032 : core (129) - INFO - Loading: sentiment\n", + "2022-02-14 13:47:32,275 : core (129) - INFO - Loading: constituency\n", + "2022-02-14 13:47:32,591 : core (129) - INFO - Loading: ner\n", + "2022-02-14 13:47:33,039 : core (179) - INFO - Done loading processors!\n", + "2022-02-14 13:47:33,041 : pipeline (40) - INFO - loading NLP cache from en_nlp_cache...\n", + "2022-02-14 13:47:33,051 : pipeline (42) - INFO - done!\n", + "100%|█████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 2886.78it/s]\n" ] } ], @@ -495,7 +498,7 @@ " RT [USER]: America is the most fucked up count...\n", " HOF\n", " 1\n", - " (1, 0, 2, 3, 4, 5, 7, 6, 8, 9, 10)\n", + " (1, 0, 2, 3, 4, 5, 6, 12, 7, 8, 9, 10, 11, 13,...\n", " \n", " \n", " 3\n", @@ -523,7 +526,7 @@ " Bitch YES [URL]\n", " HOF\n", " 1\n", - " (1, 0, 2, 4, 3, 5)\n", + " (1, 0, 2, 3, 4, 5)\n", " \n", " \n", " 7\n", @@ -537,28 +540,28 @@ " RT [USER]: im not fine, i need you\n", " NOT\n", " 0\n", - " (1, 0, 2, 3, 4, 5, 6, 7, 8)\n", + " (1, 0, 2, 3, 4, 5, 8, 6, 7, 9, 11, 10, 12)\n", " \n", " \n", " 9\n", " Holy shit.. 
3 months and I'll be in Italy\n", " HOF\n", " 1\n", - " (1, 2, 0, 3, 4, 5, 6, 11, 7, 8, 9, 10)\n", + " (1, 2, 0, 3, 8, 4, 5, 6, 7)\n", " \n", " \n", " 10\n", " Now I do what I want 🤪\n", " NOT\n", " 0\n", - " (1, 3, 2, 0, 4, 5, 6, 7)\n", + " (1, 4, 2, 3, 0, 5, 6, 7)\n", " \n", " \n", " 11\n", " [USER] you'd immediately stop\n", " NOT\n", " 0\n", - " (1, 2, 0, 3, 4, 7, 5, 6)\n", + " (1, 2, 7, 3, 4, 5, 6, 0)\n", " \n", " \n", " 12\n", @@ -572,7 +575,7 @@ " RT [USER]: ohhhh shit a [USER] [URL]\n", " HOF\n", " 1\n", - " (1, 0, 2, 3, 4, 5, 6, 7, 8, 10, 9, 11)\n", + " (1, 0, 2, 3, 4, 5, 6, 10, 7, 8, 9, 11, 12, 13,...\n", " \n", " \n", " 14\n", @@ -614,18 +617,18 @@ " graph \n", "0 (1, 0, 2, 3, 4, 5, 6) \n", "1 (1, 3, 2, 0, 4, 5, 8, 6, 7, 9, 10, 11, 13, 12,... \n", - "2 (1, 0, 2, 3, 4, 5, 7, 6, 8, 9, 10) \n", + "2 (1, 0, 2, 3, 4, 5, 6, 12, 7, 8, 9, 10, 11, 13,... \n", "3 (1, 4, 2, 3, 0, 5, 7, 6, 8, 10, 9, 11, 12, 13,... \n", "4 (1, 3, 2, 0, 4, 7, 5, 6, 8, 9, 10) \n", "5 (1, 0, 2, 3) \n", - "6 (1, 0, 2, 4, 3, 5) \n", + "6 (1, 0, 2, 3, 4, 5) \n", "7 (1, 2, 0, 3, 4, 5) \n", - "8 (1, 0, 2, 3, 4, 5, 6, 7, 8) \n", - "9 (1, 2, 0, 3, 4, 5, 6, 11, 7, 8, 9, 10) \n", - "10 (1, 3, 2, 0, 4, 5, 6, 7) \n", - "11 (1, 2, 0, 3, 4, 7, 5, 6) \n", + "8 (1, 0, 2, 3, 4, 5, 8, 6, 7, 9, 11, 10, 12) \n", + "9 (1, 2, 0, 3, 8, 4, 5, 6, 7) \n", + "10 (1, 4, 2, 3, 0, 5, 6, 7) \n", + "11 (1, 2, 7, 3, 4, 5, 6, 0) \n", "12 (1, 3, 2, 0, 4, 5, 6) \n", - "13 (1, 0, 2, 3, 4, 5, 6, 7, 8, 10, 9, 11) \n", + "13 (1, 0, 2, 3, 4, 5, 6, 10, 7, 8, 9, 11, 12, 13,... \n", "14 (1, 4, 2, 3, 0, 5, 8, 6, 7, 9) \n", "15 (1, 0, 2) " ] @@ -682,7 +685,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "16it [00:00, 1383.72it/s]\n" + "16it [00:00, 8034.10it/s]\n" ] }, { @@ -874,7 +877,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "16it [00:00, 2464.97it/s]\n" + "16it [00:00, 6796.52it/s]\n" ] }, { @@ -1066,7 +1069,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "16it [00:00, 2086.65it/s]\n" + "16it [00:00, 3718.15it/s]\n" ] }, { @@ -1258,7 +1261,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "16it [00:00, 4242.56it/s]\n" + "16it [00:00, 7336.71it/s]\n" ] }, { @@ -1464,6 +1467,203 @@ { "cell_type": "code", "execution_count": 18, + "id": "f066ebee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
textlabellabel_idgraph
0fuck absolutely everything about today.HOF1(1, 0, 2, 3, 4, 5, 6)
1I just made food and I'm making myself sick to...HOF1(1, 3, 2, 0, 4, 5, 8, 6, 7, 9, 10, 11, 13, 12,...
2RT [USER]: America is the most fucked up count...HOF1(1, 0, 2, 3, 4, 5, 6, 12, 7, 8, 9, 10, 11, 13,...
3you'd be blind to not see the heart eyes i hav...NOT0(1, 4, 2, 3, 0, 5, 7, 6, 8, 10, 9, 11, 12, 13,...
4It's hard for me to give a fuck nowHOF1(1, 3, 2, 0, 4, 7, 5, 6, 8, 9, 10)
5tell me everythingNOT0(1, 0, 2, 3)
6Bitch YES [URL]HOF1(1, 0, 2, 3, 4, 5)
7Eight people a minute....NOT0(1, 2, 0, 3, 4, 5)
8RT [USER]: im not fine, i need youNOT0(1, 0, 2, 3, 4, 5, 8, 6, 7, 9, 11, 10, 12)
9Holy shit.. 3 months and I'll be in ItalyHOF1(1, 2, 0, 3, 8, 4, 5, 6, 7)
10Now I do what I want 🤪NOT0(1, 4, 2, 3, 0, 5, 6, 7)
11[USER] you'd immediately stopNOT0(1, 2, 7, 3, 4, 5, 6, 0)
12Just... shut the fuck upHOF1(1, 3, 2, 0, 4, 5, 6)
13RT [USER]: ohhhh shit a [USER] [URL]HOF1(1, 0, 2, 3, 4, 5, 6, 10, 7, 8, 9, 11, 12, 13,...
14all i want is for yara to survive tonightNOT0(1, 4, 2, 3, 0, 5, 8, 6, 7, 9)
15fuck themHOF1(1, 0, 2)
\n", + "
" + ], + "text/plain": [ + " text label label_id \\\n", + "0 fuck absolutely everything about today. HOF 1 \n", + "1 I just made food and I'm making myself sick to... HOF 1 \n", + "2 RT [USER]: America is the most fucked up count... HOF 1 \n", + "3 you'd be blind to not see the heart eyes i hav... NOT 0 \n", + "4 It's hard for me to give a fuck now HOF 1 \n", + "5 tell me everything NOT 0 \n", + "6 Bitch YES [URL] HOF 1 \n", + "7 Eight people a minute.... NOT 0 \n", + "8 RT [USER]: im not fine, i need you NOT 0 \n", + "9 Holy shit.. 3 months and I'll be in Italy HOF 1 \n", + "10 Now I do what I want 🤪 NOT 0 \n", + "11 [USER] you'd immediately stop NOT 0 \n", + "12 Just... shut the fuck up HOF 1 \n", + "13 RT [USER]: ohhhh shit a [USER] [URL] HOF 1 \n", + "14 all i want is for yara to survive tonight NOT 0 \n", + "15 fuck them HOF 1 \n", + "\n", + " graph \n", + "0 (1, 0, 2, 3, 4, 5, 6) \n", + "1 (1, 3, 2, 0, 4, 5, 8, 6, 7, 9, 10, 11, 13, 12,... \n", + "2 (1, 0, 2, 3, 4, 5, 6, 12, 7, 8, 9, 10, 11, 13,... \n", + "3 (1, 4, 2, 3, 0, 5, 7, 6, 8, 10, 9, 11, 12, 13,... \n", + "4 (1, 3, 2, 0, 4, 7, 5, 6, 8, 9, 10) \n", + "5 (1, 0, 2, 3) \n", + "6 (1, 0, 2, 3, 4, 5) \n", + "7 (1, 2, 0, 3, 4, 5) \n", + "8 (1, 0, 2, 3, 4, 5, 8, 6, 7, 9, 11, 10, 12) \n", + "9 (1, 2, 0, 3, 8, 4, 5, 6, 7) \n", + "10 (1, 4, 2, 3, 0, 5, 6, 7) \n", + "11 (1, 2, 7, 3, 4, 5, 6, 0) \n", + "12 (1, 3, 2, 0, 4, 5, 6) \n", + "13 (1, 0, 2, 3, 4, 5, 6, 10, 7, 8, 9, 11, 12, 13,... \n", + "14 (1, 4, 2, 3, 0, 5, 8, 6, 7, 9) \n", + "15 (1, 0, 2) " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 19, "id": "50f43ae8", "metadata": {}, "outputs": [], @@ -1475,18 +1675,181 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "id": "a672214c", "metadata": {}, "outputs": [], "source": [ - "train.to_pickle(\"train_dataset\")\n", - "val.to_pickle(\"val_dataset\")" + "from xpotato.dataset.utils import save_dataframe\n", + "\n", + "save_dataframe(train, \"train_dataset.tsv\")\n", + "save_dataframe(val, \"val_dataset.tsv\")" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, + "id": "a33b3122", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
textlabellabel_idgraph
14all i want is for yara to survive tonightNOT0(1, 4, 2, 3, 0, 5, 8, 6, 7, 9)
2RT [USER]: America is the most fucked up count...HOF1(1, 0, 2, 3, 4, 5, 6, 12, 7, 8, 9, 10, 11, 13,...
10Now I do what I want 🤪NOT0(1, 4, 2, 3, 0, 5, 6, 7)
7Eight people a minute....NOT0(1, 2, 0, 3, 4, 5)
1I just made food and I'm making myself sick to...HOF1(1, 3, 2, 0, 4, 5, 8, 6, 7, 9, 10, 11, 13, 12,...
9Holy shit.. 3 months and I'll be in ItalyHOF1(1, 2, 0, 3, 8, 4, 5, 6, 7)
8RT [USER]: im not fine, i need youNOT0(1, 0, 2, 3, 4, 5, 8, 6, 7, 9, 11, 10, 12)
4It's hard for me to give a fuck nowHOF1(1, 3, 2, 0, 4, 7, 5, 6, 8, 9, 10)
5tell me everythingNOT0(1, 0, 2, 3)
6Bitch YES [URL]HOF1(1, 0, 2, 3, 4, 5)
3you'd be blind to not see the heart eyes i hav...NOT0(1, 4, 2, 3, 0, 5, 7, 6, 8, 10, 9, 11, 12, 13,...
15fuck themHOF1(1, 0, 2)
\n", + "
" + ], + "text/plain": [ + " text label label_id \\\n", + "14 all i want is for yara to survive tonight NOT 0 \n", + "2 RT [USER]: America is the most fucked up count... HOF 1 \n", + "10 Now I do what I want 🤪 NOT 0 \n", + "7 Eight people a minute.... NOT 0 \n", + "1 I just made food and I'm making myself sick to... HOF 1 \n", + "9 Holy shit.. 3 months and I'll be in Italy HOF 1 \n", + "8 RT [USER]: im not fine, i need you NOT 0 \n", + "4 It's hard for me to give a fuck now HOF 1 \n", + "5 tell me everything NOT 0 \n", + "6 Bitch YES [URL] HOF 1 \n", + "3 you'd be blind to not see the heart eyes i hav... NOT 0 \n", + "15 fuck them HOF 1 \n", + "\n", + " graph \n", + "14 (1, 4, 2, 3, 0, 5, 8, 6, 7, 9) \n", + "2 (1, 0, 2, 3, 4, 5, 6, 12, 7, 8, 9, 10, 11, 13,... \n", + "10 (1, 4, 2, 3, 0, 5, 6, 7) \n", + "7 (1, 2, 0, 3, 4, 5) \n", + "1 (1, 3, 2, 0, 4, 5, 8, 6, 7, 9, 10, 11, 13, 12,... \n", + "9 (1, 2, 0, 3, 8, 4, 5, 6, 7) \n", + "8 (1, 0, 2, 3, 4, 5, 8, 6, 7, 9, 11, 10, 12) \n", + "4 (1, 3, 2, 0, 4, 7, 5, 6, 8, 9, 10) \n", + "5 (1, 0, 2, 3) \n", + "6 (1, 0, 2, 3, 4, 5) \n", + "3 (1, 4, 2, 3, 0, 5, 7, 6, 8, 10, 9, 11, 12, 13,... \n", + "15 (1, 0, 2) " + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train" + ] + }, + { + "cell_type": "code", + "execution_count": 22, "id": "d42d4556", "metadata": { "slideshow": { @@ -1508,7 +1871,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "id": "3a5d67d6", "metadata": { "slideshow": { @@ -1527,7 +1890,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "12it [00:00, 225.73it/s]\n" + "12it [00:00, 213.54it/s]" ] }, { @@ -1540,6 +1903,13 @@ "Training...\n", "Getting features...\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ @@ -1548,7 +1918,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "id": "9f628397", "metadata": {}, "outputs": [ @@ -1556,18 +1926,14 @@ "data": { "text/plain": [ "defaultdict(list,\n", - " {'HOF': [(['(u_13 / fuck)'], [], 'HOF'),\n", + " {'HOF': [(['(u_19 / fuck)'], [], 'HOF'),\n", " (['(u_1 / be)'], [], 'HOF'),\n", - " (['(u_17 / url :punct (u_16 / LSB) :punct (u_18 / RSB))'],\n", - " [],\n", - " 'HOF'),\n", - " (['(u_17 / url :punct (u_16 / LSB))'], [], 'HOF'),\n", - " (['(u_16 / LSB)'], [], 'HOF'),\n", - " (['(u_8 / to)'], [], 'HOF'),\n", - " (['(u_21 / I)'], [], 'HOF')]})" + " (['(u_13 / RSB)'], [], 'HOF'),\n", + " (['(u_11 / LSB)'], [], 'HOF'),\n", + " (['(u_8 / to)'], [], 'HOF')]})" ] }, - "execution_count": 22, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1578,7 +1944,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "id": "25eb919f", "metadata": {}, "outputs": [], diff --git a/notebooks/openie.ipynb b/notebooks/openie.ipynb new file mode 100644 index 0000000..bf781e1 --- /dev/null +++ b/notebooks/openie.ipynb @@ -0,0 +1,399 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "17da8e39", + "metadata": {}, + "source": [ + "## Minimal example of OpenIE" + ] + }, + { + "cell_type": "markdown", + "id": "e8f2d686", + "metadata": {}, + "source": [ + "First define the sentence you want to annotate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c8c701b1", + "metadata": {}, + "outputs": [], + "source": [ + "sentences = ['This property is defined by ComponentType defined in OPC UA DI.', \n", + " 'The IMachineryItemVendorNameplateType is a subtype of the 
2:IVendorNameplateType defined in OPC 10000-100.']" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "7d7cfb0f", + "metadata": {}, + "outputs": [], + "source": [ + "# Import the evaluators\n", + "from xpotato.graph_extractor.extract import GraphExtractor, FeatureEvaluator" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "51a416b1", + "metadata": {}, + "outputs": [], + "source": [ + "evaluator = FeatureEvaluator()\n", + "extractor = GraphExtractor(cache_fn='openie_en')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dde383b7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-03-08 14:20:00 INFO: Loading these models for language: en (English):\n", + "============================\n", + "| Processor | Package |\n", + "----------------------------\n", + "| tokenize | combined |\n", + "| pos | combined |\n", + "| lemma | combined |\n", + "| depparse | combined |\n", + "| sentiment | sstplus |\n", + "| constituency | wsj |\n", + "| ner | ontonotes |\n", + "============================\n", + "\n", + "2022-03-08 14:20:00,550 : core (112) - INFO - Loading these models for language: en (English):\n", + "============================\n", + "| Processor | Package |\n", + "----------------------------\n", + "| tokenize | combined |\n", + "| pos | combined |\n", + "| lemma | combined |\n", + "| depparse | combined |\n", + "| sentiment | sstplus |\n", + "| constituency | wsj |\n", + "| ner | ontonotes |\n", + "============================\n", + "\n", + "2022-03-08 14:20:00 INFO: Use device: cpu\n", + "2022-03-08 14:20:00,552 : core (123) - INFO - Use device: cpu\n", + "2022-03-08 14:20:00 INFO: Loading: tokenize\n", + "2022-03-08 14:20:00,553 : core (129) - INFO - Loading: tokenize\n", + "2022-03-08 14:20:00 INFO: Loading: pos\n", + "2022-03-08 14:20:00,562 : core (129) - INFO - Loading: pos\n", + "2022-03-08 14:20:00 INFO: Loading: lemma\n", + "2022-03-08 14:20:00,821 : core (129) - INFO - Loading: lemma\n", + "2022-03-08 14:20:01 INFO: Loading: depparse\n", + "2022-03-08 14:20:01,039 : core (129) - INFO - Loading: depparse\n", + "2022-03-08 14:20:01 INFO: Loading: sentiment\n", + "2022-03-08 14:20:01,397 : core (129) - INFO - Loading: sentiment\n", + "2022-03-08 14:20:01 INFO: Loading: constituency\n", + "2022-03-08 14:20:01,807 : core (129) - INFO - Loading: constituency\n", + "2022-03-08 14:20:02 INFO: Loading: ner\n", + "2022-03-08 14:20:02,258 : core (129) - INFO - Loading: ner\n", + "2022-03-08 14:20:02 INFO: Done loading processors!\n", + "2022-03-08 14:20:02,669 : core (179) - INFO - Done loading processors!\n", + "2022-03-08 14:20:02,671 : pipeline (45) - INFO - creating new NLP cache in openie_en\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.09it/s]\n" + ] + } + ], + "source": [ + "# Parse the sentences to graphs\n", + "graphs = list(extractor.parse_iterable(sentences, graph_type='ud', lang='en'))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "1b121538", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "finite_state_machine\n", + "\n", + "\n", + "\n", + "ComponentType\n", + "\n", + "ComponentType\n", + "\n", + "\n", + 
"\n", + "by\n", + "\n", + "by\n", + "\n", + "\n", + "\n", + "ComponentType->by\n", + "\n", + "\n", + "case\n", + "\n", + "\n", + "\n", + "define\n", + "\n", + "define\n", + "\n", + "\n", + "\n", + "ComponentType->define\n", + "\n", + "\n", + "acl\n", + "\n", + "\n", + "\n", + "DI\n", + "\n", + "DI\n", + "\n", + "\n", + "\n", + "OPC\n", + "\n", + "OPC\n", + "\n", + "\n", + "\n", + "OPC->DI\n", + "\n", + "\n", + "flat\n", + "\n", + "\n", + "\n", + "UA\n", + "\n", + "UA\n", + "\n", + "\n", + "\n", + "OPC->UA\n", + "\n", + "\n", + "flat\n", + "\n", + "\n", + "\n", + "in\n", + "\n", + "in\n", + "\n", + "\n", + "\n", + "OPC->in\n", + "\n", + "\n", + "case\n", + "\n", + "\n", + "\n", + "PERIOD\n", + "\n", + "PERIOD\n", + "\n", + "\n", + "\n", + "be\n", + "\n", + "be\n", + "\n", + "\n", + "\n", + "define->ComponentType\n", + "\n", + "\n", + "obl\n", + "\n", + "\n", + "\n", + "define->OPC\n", + "\n", + "\n", + "obl\n", + "\n", + "\n", + "\n", + "define->PERIOD\n", + "\n", + "\n", + "punct\n", + "\n", + "\n", + "\n", + "define->be\n", + "\n", + "\n", + "auxCOLONpass\n", + "\n", + "\n", + "\n", + "property\n", + "\n", + "property\n", + "\n", + "\n", + "\n", + "define->property\n", + "\n", + "\n", + "nsubjCOLONpass\n", + "\n", + "\n", + "\n", + "this\n", + "\n", + "this\n", + "\n", + "\n", + "\n", + "property->this\n", + "\n", + "\n", + "det\n", + "\n", + "\n", + "\n", + "root\n", + "\n", + "root\n", + "\n", + "\n", + "\n", + "root->define\n", + "\n", + "\n", + "root\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We can also check any of the graphs\n", + "from xpotato.models.utils import to_dot\n", + "from graphviz import Source\n", + "\n", + "Source(to_dot(graphs[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "3112e624", + "metadata": {}, + "outputs": [], + "source": [ + "# Define the featureset\n", + "feature1 = ['(u_0 / define :obl (u_1 / .*) :nsubj.* (u_2 / .*))'], [], 'DEFINED', [{'ARG1': 1, 'ARG2': 2}]\n", + "feature2 = ['(u_0 / subtype :nmod (u_1 / .*) :nsubj (u_2 / .*))'], [], 'SUBTYPE', [{'ARG1': 1, 'ARG2': 2}]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "847df12f", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a ruleset from the features\n", + "from xpotato.graph_extractor.rule import RuleSet, Rule\n", + "\n", + "rule_set = RuleSet([Rule(feature1, openie=True), Rule(feature2, openie=True)])" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "c7fd16bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Annotate sentence with triplets\n", + "triplets = list(evaluator.annotate(graphs[0], rule_set.to_list()))" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "620e73ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'relation': 'DEFINED', 'ARG1': 'ComponentType', 'ARG2': 'property'}]" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "triplets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bd8ed4b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "base" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", 
+ "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/relation_examples.ipynb b/notebooks/relation_examples.ipynb index 3fdb662..933367f 100644 --- a/notebooks/relation_examples.ipynb +++ b/notebooks/relation_examples.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -87,52 +87,35 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2021-12-03 12:22:11 WARNING: Can not find mwt: default from official model list. Ignoring it.\n", - "WARNING:stanza:Can not find mwt: default from official model list. Ignoring it.\n", - "2021-12-03 12:22:11 INFO: Loading these models for language: en (English):\n", - "=======================\n", - "| Processor | Package |\n", - "-----------------------\n", - "| tokenize | ewt |\n", - "| pos | ewt |\n", - "| lemma | ewt |\n", - "| depparse | ewt |\n", - "=======================\n", + "2022-02-14 12:39:47,351 : common (213) - WARNING - Can not find mwt: default from official model list. Ignoring it.\n", + "2022-02-14 12:39:47,353 : core (112) - INFO - Loading these models for language: en (English):\n", + "========================\n", + "| Processor | Package |\n", + "------------------------\n", + "| tokenize | combined |\n", + "| pos | combined |\n", + "| lemma | combined |\n", + "| depparse | combined |\n", + "========================\n", "\n", - "INFO:stanza:Loading these models for language: en (English):\n", - "=======================\n", - "| Processor | Package |\n", - "-----------------------\n", - "| tokenize | ewt |\n", - "| pos | ewt |\n", - "| lemma | ewt |\n", - "| depparse | ewt |\n", - "=======================\n", - "\n", - "2021-12-03 12:22:11 INFO: Use device: cpu\n", - "INFO:stanza:Use device: cpu\n", - "2021-12-03 12:22:11 INFO: Loading: tokenize\n", - "INFO:stanza:Loading: tokenize\n", - "2021-12-03 12:22:11 INFO: Loading: pos\n", - "INFO:stanza:Loading: pos\n", - "2021-12-03 12:22:12 INFO: Loading: lemma\n", - "INFO:stanza:Loading: lemma\n", - "2021-12-03 12:22:12 INFO: Loading: depparse\n", - "INFO:stanza:Loading: depparse\n", - "2021-12-03 12:22:12 INFO: Done loading processors!\n", - "INFO:stanza:Done loading processors!\n", - "WARNING:root:loading NLP cache from en_nlp_cache...\n", - "WARNING:root:done!\n", - "WARNING:root:loading cache from file: cache/UD_FL.json\n", - "WARNING:root:loaded cache from cache/UD_FL.json with interpretations: ['fl', 'ud']\n", - "100%|██████████| 18/18 [00:00<00:00, 1322.89it/s]\n" + "2022-02-14 12:39:47,376 : core (123) - INFO - Use device: cpu\n", + "2022-02-14 12:39:47,377 : core (129) - INFO - Loading: tokenize\n", + "2022-02-14 12:39:47,382 : core (129) - INFO - Loading: pos\n", + "2022-02-14 12:39:47,537 : core (129) - INFO - Loading: lemma\n", + "2022-02-14 12:39:47,562 : core (129) - INFO - Loading: depparse\n", + "2022-02-14 12:39:47,852 : core (179) - INFO - Done loading processors!\n", + "2022-02-14 12:39:47,853 : pipeline (40) - INFO - loading NLP cache from en_nlp_cache...\n", + "2022-02-14 12:39:47,862 : pipeline (42) - INFO - done!\n", + 
"2022-02-14 12:39:47,863 : irtg (81) - INFO - loading cache from file: cache/UD_FL.json\n", + "2022-02-14 12:39:47,864 : irtg (21) - INFO - loaded cache from cache/UD_FL.json with interpretations: ['fl', 'ud']\n", + "100%|█████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 1283.86it/s]\n" ] } ], @@ -143,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -153,251 +136,167 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
textlabellabel_idgraph
0Governments and industries in nations around t...Entity-Destination(e1,e2)1(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
1The scientists poured XXX into pint YYY.Entity-Destination(e1,e2)1(0, 1, 11, 9, 12, 13)
2The suspect pushed the XXX into a deep YYY.Entity-Destination(e1,e2)1(14, 1, 15, 9, 12, 16)
3The Nepalese government sets up a XXX to inqui...Other0(17, 1, 4, 18, 19, 20, 9, 12, 21, 22, 23, 24)
4The entity1 to buy papers is pushed into the n...Entity-Destination(e1,e2)1(14, 25, 26, 27, 9, 28, 29, 30)
5An unnamed XXX was pushed into the YYY.Entity-Destination(e1,e2)1(14, 1, 31, 9, 12)
6Since then, numerous independent feature XXX h...Other0(32, 1, 33, 34, 35, 9, 12, 36, 37)
7For some reason, the XXX was blinded from his ...Other0(38, 1, 39, 40, 22, 41, 42, 12, 43, 44, 45, 46...
8Sparky Anderson is making progress in his XXX ...Other0(2, 48, 19, 49, 50, 51, 52, 53, 54, 55, 5, 1, ...
9Olympics have already poured one XXX into the ...Entity-Destination(e1,e2)1(0, 1, 58, 59, 60, 9, 12)
10After wrapping him in a light blanket, they pl...Entity-Destination(e1,e2)1(61, 1, 62, 5, 12, 63, 46, 45, 64, 65, 66, 67)
11I placed the XXX in a natural YYY, at the base...Entity-Destination(e1,e2)1(61, 1, 68, 69, 70, 22, 71, 72, 73, 5, 12, 74)
12The XXX was delivered from the YYY of Lincoln ...Other0(75, 76, 77, 71, 22, 78, 79, 80, 81, 44, 45, 8...
13The XXX leaked from every conceivable YYY.Other0(1, 85, 42, 12, 86)
14The scientists placed the XXX in a tiny YYY wh...Entity-Destination(e1,e2)1(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91, ...
15The level surface closest to the MSS, known as...Other0(95, 96, 5, 97, 98, 99, 19, 100, 101, 77, 1, 1...
16Gaza XXX recover from three YYY of war.Other0(103, 104, 42, 12, 105, 22, 106)
 [notebook output trimmed: HTML and text/plain renders of the 18-row example dataframe, plus the graphviz SVG for sentence 0 ("Governments and industries in nations around the world are pouring XXX into YYY", nodes: pour, into, COORD, government, industry, nation, world, xxx, YYY); a duplicate "df" display cell was removed and the SVG output is now stored as a list of lines]
@@ -407,7 +306,7 @@
     "from xpotato.models.utils import to_dot\n",
     "from graphviz import Source\n",
     "\n",
-    "Source(to_dot(dataset.graphs[0]))"
+    "Source(to_dot(df.iloc[0].graph))"
 [execution counts in the following cells are renumbered (99-145 -> 7-14) and tqdm timings change on rerun]
[ { "data": { "text/plain": [ - "['(u_1 / into :1 (u_2 / push|pour) :2 (u_3 / YYY))']" + "['(u_1 / into :1 (u_2 / pour|push) :2 (u_3 / YYY))']" ] }, - "execution_count": 107, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -1415,7 +1314,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -1426,17 +1325,7 @@ }, { "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "train.to_pickle(\"train_dataset\")\n", - "val.to_pickle(\"val_dataset\")" - ] - }, - { - "cell_type": "code", - "execution_count": 42, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -1453,7 +1342,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -1467,7 +1356,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "14it [00:00, 350.24it/s]" + "14it [00:00, 279.91it/s]" ] }, { @@ -1495,31 +1384,25 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defaultdict(list,\n", - " {'Entity-Destination(e1,e2)': [(['(u_15 / into :1 (u_26 / push))'],\n", - " [],\n", - " 'Entity-Destination(e1,e2)'),\n", - " (['(u_15 / into :1 (u_19 / pour :2 (u_0 / xxx)))'],\n", + " {'Entity-Destination(e1,e2)': [(['(u_15 / into :1 (u_19 / pour :2 (u_1 / xxx)))'],\n", " [],\n", " 'Entity-Destination(e1,e2)'),\n", " (['(u_15 / into :1 (u_19 / pour))'],\n", " [],\n", " 'Entity-Destination(e1,e2)'),\n", - " (['(u_19 / pour :2 (u_0 / xxx))'],\n", - " [],\n", - " 'Entity-Destination(e1,e2)'),\n", - " (['(u_15 / into :2 (u_3 / yyy))'],\n", + " (['(u_19 / pour :2 (u_1 / xxx))'],\n", " [],\n", " 'Entity-Destination(e1,e2)')]})" ] }, - "execution_count": 44, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1530,7 +1413,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -1539,7 +1422,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1570,8 +1453,10 @@ " Support\n", " False_positive_graphs\n", " False_positive_sens\n", + " False_positive_indices\n", " True_positive_graphs\n", " True_positive_sens\n", + " True_positive_indices\n", " False_negative_graphs\n", " False_negative_sens\n", " False_negative_indices\n", @@ -1581,147 +1466,100 @@ " \n", " \n", " 0\n", - " [(u_15 / into :1 (u_26 / push))]\n", - " 1.000000\n", + " [(u_15 / into :1 (u_19 / pour :2 (u_1 / xxx)))]\n", + " 1.0\n", " 0.428571\n", - " 0.600000\n", + " 0.6\n", " 7\n", " []\n", " []\n", - " [(14, 1, 15, 9, 12, 16), (14, 25, 26, 27, 9, 2...\n", - " [(The suspect pushed the XXX into a deep YYY.,...\n", - " [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,...\n", - " [(The scientists placed the XXX in a tiny YYY ...\n", - " [1]\n", - " [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]\n", - " \n", - " \n", - " 1\n", - " [(u_15 / into :1 (u_19 / pour :2 (u_0 / xxx)))]\n", - " 1.000000\n", - " 0.428571\n", - " 0.600000\n", - " 7\n", - " []\n", " []\n", " [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11...\n", " [(Governments and industries in nations around...\n", - " [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,...\n", + " [2, 6, 7]\n", + " [(2, 86, 35, 87, 88, 89, 61, 1, 11, 5, 12, 90,...\n", " [(The scientists placed the XXX in a tiny YYY ...\n", - " [1]\n", + " [1, 3, 9, 11]\n", " [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]\n", " \n", " 
\n", - " 2\n", + " 1\n", " [(u_15 / into :1 (u_19 / pour))]\n", - " 1.000000\n", + " 1.0\n", " 0.428571\n", - " 0.600000\n", + " 0.6\n", " 7\n", " []\n", " []\n", + " []\n", " [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11...\n", " [(Governments and industries in nations around...\n", - " [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,...\n", + " [2, 6, 7]\n", + " [(2, 86, 35, 87, 88, 89, 61, 1, 11, 5, 12, 90,...\n", " [(The scientists placed the XXX in a tiny YYY ...\n", - " [1]\n", + " [1, 3, 9, 11]\n", " [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]\n", " \n", " \n", - " 3\n", - " [(u_19 / pour :2 (u_0 / xxx))]\n", - " 1.000000\n", + " 2\n", + " [(u_19 / pour :2 (u_1 / xxx))]\n", + " 1.0\n", " 0.428571\n", - " 0.600000\n", + " 0.6\n", " 7\n", " []\n", " []\n", + " []\n", " [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11...\n", " [(Governments and industries in nations around...\n", - " [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,...\n", + " [2, 6, 7]\n", + " [(2, 86, 35, 87, 88, 89, 61, 1, 11, 5, 12, 90,...\n", " [(The scientists placed the XXX in a tiny YYY ...\n", - " [1]\n", + " [1, 3, 9, 11]\n", " [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]\n", " \n", - " \n", - " 4\n", - " [(u_15 / into :2 (u_3 / yyy))]\n", - " 0.833333\n", - " 0.714286\n", - " 0.769231\n", - " 7\n", - " [(32, 1, 33, 34, 35, 9, 12, 36, 37)]\n", - " [(Since then, numerous independent feature XXX...\n", - " [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (14, 1, 1...\n", - " [(Governments and industries in nations around...\n", - " [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,...\n", - " [(The scientists placed the XXX in a tiny YYY ...\n", - " [1]\n", - " [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0]\n", - " \n", " \n", "\n", "" ], "text/plain": [ " Feature Precision Recall \\\n", - "0 [(u_15 / into :1 (u_26 / push))] 1.000000 0.428571 \n", - "1 [(u_15 / into :1 (u_19 / pour :2 (u_0 / xxx)))] 1.000000 0.428571 \n", - "2 [(u_15 / into :1 (u_19 / pour))] 1.000000 0.428571 \n", - "3 [(u_19 / pour :2 (u_0 / xxx))] 1.000000 0.428571 \n", - "4 [(u_15 / into :2 (u_3 / yyy))] 0.833333 0.714286 \n", - "\n", - " Fscore Support False_positive_graphs \\\n", - "0 0.600000 7 [] \n", - "1 0.600000 7 [] \n", - "2 0.600000 7 [] \n", - "3 0.600000 7 [] \n", - "4 0.769231 7 [(32, 1, 33, 34, 35, 9, 12, 36, 37)] \n", + "0 [(u_15 / into :1 (u_19 / pour :2 (u_1 / xxx)))] 1.0 0.428571 \n", + "1 [(u_15 / into :1 (u_19 / pour))] 1.0 0.428571 \n", + "2 [(u_19 / pour :2 (u_1 / xxx))] 1.0 0.428571 \n", "\n", - " False_positive_sens \\\n", - "0 [] \n", - "1 [] \n", - "2 [] \n", - "3 [] \n", - "4 [(Since then, numerous independent feature XXX... \n", + " Fscore Support False_positive_graphs False_positive_sens \\\n", + "0 0.6 7 [] [] \n", + "1 0.6 7 [] [] \n", + "2 0.6 7 [] [] \n", "\n", - " True_positive_graphs \\\n", - "0 [(14, 1, 15, 9, 12, 16), (14, 25, 26, 27, 9, 2... \n", - "1 [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11... \n", - "2 [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11... \n", - "3 [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11... \n", - "4 [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (14, 1, 1... \n", + " False_positive_indices True_positive_graphs \\\n", + "0 [] [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11... \n", + "1 [] [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11... \n", + "2 [] [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 1, 11... \n", "\n", - " True_positive_sens \\\n", - "0 [(The suspect pushed the XXX into a deep YYY.,... \n", - "1 [(Governments and industries in nations around... \n", - "2 [(Governments and industries in nations around... 
\n", - "3 [(Governments and industries in nations around... \n", - "4 [(Governments and industries in nations around... \n", + " True_positive_sens True_positive_indices \\\n", + "0 [(Governments and industries in nations around... [2, 6, 7] \n", + "1 [(Governments and industries in nations around... [2, 6, 7] \n", + "2 [(Governments and industries in nations around... [2, 6, 7] \n", "\n", " False_negative_graphs \\\n", - "0 [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,... \n", - "1 [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,... \n", - "2 [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,... \n", - "3 [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,... \n", - "4 [(61, 1, 11, 5, 12, 2, 87, 37, 88, 89, 90, 91,... \n", + "0 [(2, 86, 35, 87, 88, 89, 61, 1, 11, 5, 12, 90,... \n", + "1 [(2, 86, 35, 87, 88, 89, 61, 1, 11, 5, 12, 90,... \n", + "2 [(2, 86, 35, 87, 88, 89, 61, 1, 11, 5, 12, 90,... \n", "\n", " False_negative_sens False_negative_indices \\\n", - "0 [(The scientists placed the XXX in a tiny YYY ... [1] \n", - "1 [(The scientists placed the XXX in a tiny YYY ... [1] \n", - "2 [(The scientists placed the XXX in a tiny YYY ... [1] \n", - "3 [(The scientists placed the XXX in a tiny YYY ... [1] \n", - "4 [(The scientists placed the XXX in a tiny YYY ... [1] \n", + "0 [(The scientists placed the XXX in a tiny YYY ... [1, 3, 9, 11] \n", + "1 [(The scientists placed the XXX in a tiny YYY ... [1, 3, 9, 11] \n", + "2 [(The scientists placed the XXX in a tiny YYY ... [1, 3, 9, 11] \n", "\n", " Predicted \n", - "0 [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] \n", + "0 [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0] \n", "1 [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0] \n", - "2 [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0] \n", - "3 [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0] \n", - "4 [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0] " + "2 [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0] " ] }, - "execution_count": 52, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -1732,7 +1570,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -1751,14 +1589,14 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "8000it [00:00, 189262.97it/s]\n" + "8000it [00:00, 175511.33it/s]\n" ] } ], @@ -1793,7 +1631,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -1806,7 +1644,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -1949,7 +1787,7 @@ "[8000 rows x 4 columns]" ] }, - "execution_count": 110, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1967,16 +1805,16 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "dataset.load_graphs(\"../features/semeval/semeval_train_4lang_graphs.pickle\")" + "dataset.load_graphs(\"../features/semeval/semeval_train_4lang_graphs.pickle\", binary=True)" ] }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1985,7 +1823,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1996,7 +1834,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 28, "metadata": {}, "outputs": [], "source": 
 [semeval outputs rerun: suggested feature node ids shift (u_15 / into -> u_8 / into, u_14 / in -> u_19 / in, u_1200 / give -> u_1196 / give) with precision/recall/F-score unchanged, the refined into-rule lists the same 22 verbs in a different order, and its scores move from 0.988636/0.268519/0.42233 to 0.988701/0.270062/0.424242; execution counts renumbered, tqdm timings updated]
@@ -2306,13 +2144,29 @@
 },
 {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "from xpotato.dataset.utils import save_dataframe"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": 38,
  "metadata": {},
  "outputs": [],
  "source": [
-  "tr_df.to_pickle(\"semeval_train.pickle\")\n",
-  "val_df.to_pickle(\"semeval_val.pickle\")"
+  "save_dataframe(tr_df, 'semeval_train.tsv')\n",
+  "save_dataframe(val_df, 'semeval_val.tsv')"
  ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
 }
@@ -2320,7 +2174,7 @@
   "hash": "2ba6e79cfe500659b64dde21d0b13217ce6375f8dca9d4d575440e3878ce882b"
  },
  "kernelspec": {
-  "display_name": "Python 3.9.5 64-bit ('base': conda)",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
@@ -2335,8 +2189,7 @@
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
- },
- "orig_nbformat": 4
+ }
 },
 "nbformat": 4,
 "nbformat_minor": 2
diff --git a/scripts/convert_pickle.py b/scripts/convert_pickle.py
new file mode 100644
index 0000000..de6b85c
--- /dev/null
+++ b/scripts/convert_pickle.py
@@ -0,0 +1,27 @@
+import argparse
+
+from xpotato.dataset.dataset import Dataset
+from xpotato.dataset.utils import save_dataframe
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Convert a pickled dataset to TSV.")
+    parser.add_argument("-p", "--pickle", type=str, required=True)
+    parser.add_argument("-o", "--output", type=str, required=True)
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    path = args.pickle
+    output = args.output
+
+    dataset = Dataset(path=path, binary=True)
+    df = dataset.to_dataframe()
+
+    save_dataframe(df, output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index d4d83c1..29241f6 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -7,6 +7,7 @@
 from sklearn.metrics import classification_report
 
 from xpotato.graph_extractor.extract import FeatureEvaluator
+from xpotato.graph_extractor.graph import PotatoGraph
 
 
 # TODO Adam: This is not the best place for these functions but I didn't want it to be in the frontend.utils
@@ -18,8 +19,16 @@
 def filter_label(df, label):
     df["label_id"] = df.apply(lambda x: 0 if x["label"] == "NOT" else 1, axis=1)
 
 
-def read_val(path, label=None):
-    df = pd.read_pickle(path)
+def read_df(path, label=None, binary=False):
+    if binary:
+        df = pd.read_pickle(path)
+    else:
+        df = pd.read_csv(path, sep="\t")
+        graphs = []
+        for graph in df["graph"]:
potato_graph = PotatoGraph(graph_str=graph) + graphs.append(potato_graph.graph) + df["graph"] = graphs if label is not None: filter_label(df, label) return df @@ -52,7 +61,7 @@ def main(): ) args = get_args() - df = read_val(args.dataset_path, args.label) + df = read_df(args.dataset_path, args.label) with open(args.features) as f: features = json.load(f) diff --git a/setup.py b/setup.py index 28a2d55..b860251 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="xpotato", - version="0.0.9", + version="0.1.0", description="XAI human-in-the-loop information extraction framework", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", @@ -11,24 +11,20 @@ author_email="adam.kovacs@tuwien.ac.at, gabor.recski@tuwien.ac.at", license="MIT", install_requires=[ - "beautifulsoup4", - "tinydb", - "pandas", + "pandas >= 1.3.5", "tqdm", - "stanza", - "sklearn", - "eli5", - "matplotlib", - "graphviz", - "openpyxl", - "penman", - "networkx >= 2.6.3", - "rank_bm25", + "stanza == 1.3.0", + "scikit-learn == 1.0.2", + "eli5 == 0.11.0", + "graphviz == 0.18.2", + "penman >= 1.2.1", + "networkx == 2.6.3", + "rank_bm25 == 0.2.1", "streamlit == 1.3.1", - "streamlit-aggrid", - "scikit-criteria >= 0.5", - "tuw-nlp", - "amrlib", + "streamlit-aggrid == 0.2.3.post2", + "scikit-criteria == 0.5", + "tuw-nlp >= 0.0.4", + "amrlib == 0.6.0", ], packages=find_packages(), classifiers=[ @@ -39,8 +35,6 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], diff --git a/tests/features.json b/tests/features.json new file mode 100644 index 0000000..7862bc2 --- /dev/null +++ b/tests/features.json @@ -0,0 +1 @@ +{"FIND": [[["(u_0 / find :obl (u_1 / .*) :obj (u_2 / .*))"], [], "FIND"]], "WRITE": [[["(u_0 / write :nsubj (u_1 / .*) :obj (u_2 / .*))"], [], "WRITE"]]} \ No newline at end of file diff --git a/tests/features.tsv b/tests/features.tsv new file mode 100644 index 0000000..4175f6c --- /dev/null +++ b/tests/features.tsv @@ -0,0 +1,2 @@ +FIND (u_0 / find :obl (u_1 / .*) :obj (u_2 / .*)) +WRITE (u_0 / write :nsubj (u_1 / .*) :obj (u_2 / .*)) diff --git a/tests/features_openie.tsv b/tests/features_openie.tsv new file mode 100644 index 0000000..47ccf01 --- /dev/null +++ b/tests/features_openie.tsv @@ -0,0 +1,2 @@ +FIND (u_0 / find :obl (u_1 / .*) :obj (u_2 / .*)) [{"ARG1": 1, "ARG2": 2}] +WRITE (u_0 / write :nsubj (u_1 / .*) :obj (u_2 / .*)) [{"ARG1": 1, "ARG2": 2}] diff --git a/tests/test_ruleset.py b/tests/test_ruleset.py new file mode 100644 index 0000000..cc42850 --- /dev/null +++ b/tests/test_ruleset.py @@ -0,0 +1,102 @@ +from xpotato.dataset.utils import default_pn_to_graph +from xpotato.graph_extractor.extract import FeatureEvaluator +from xpotato.graph_extractor.rule import RuleSet, Rule + +import os + +dir_name = os.path.dirname(os.path.realpath(__file__)) + +FEATURE1 = [["(u_0 / find :obl (u_1 / .*) :obj (u_2 / .*))"], [], "FIND"] +FEATURE2 = [["(u_0 / write :nsubj (u_1 / .*) :obj (u_2 / .*))"], [], "WRITE"] + +FEATURE3 = [ + ["(u_0 / find :obl (u_1 / .*) :obj (u_2 / .*))"], + [], + "FIND", + [{"ARG1": 1, "ARG2": 2}], +] +FEATURE4 = [ + ["(u_0 / write :nsubj (u_1 / .*) :obj (u_2 / .*))"], + [], + "WRITE", + [{"ARG1": 1, "ARG2": 2}], +] + +FEATURE_DICT = { + "FIND": [[["(u_0 / find :obl (u_1 / .*) 
:obj (u_2 / .*))"], [], "FIND"]], + "WRITE": [[["(u_0 / write :nsubj (u_1 / .*) :obj (u_2 / .*))"], [], "WRITE"]], +} + +GRAPH = "(u_2 / write :nsubj (u_1 / person) :obj (u_4 / sentence :det (u_3 / this)) :conj (u_6 / find :cc (u_5 / and) :obj (u_7 / object) :obl (u_9 / location :case (u_8 / in))) :punct (u_10 / PERIOD) :root-of (u_0 / root))" + + +def test_rule(): + rule1 = Rule(FEATURE1) + + assert rule1.to_list() == FEATURE1 + + +def test_ruleset_to_list(): + rule_set = RuleSet([Rule(FEATURE1)]) + + rule_set.add_rule(Rule(FEATURE2)) + + assert rule_set.to_list() == [FEATURE1, FEATURE2] + + +def test_ruleset_to_dict(): + rule_set = RuleSet([Rule(FEATURE1), Rule(FEATURE2)]) + + assert rule_set.to_dict() == {"FIND": [FEATURE1], "WRITE": [FEATURE2]} + + +def test_ruleset_from_dict_to_list(): + rule_set = RuleSet() + rule_set.from_dict(FEATURE_DICT) + + assert rule_set.to_list() == [FEATURE1, FEATURE2] + + +def test_ruleset_json(): + rule_set = RuleSet() + + rule_set.from_json(os.path.join(dir_name, "features.json")) + + assert rule_set.to_list() == [FEATURE1, FEATURE2] + + +def test_ruleset_to_tsv(): + rule_set = RuleSet([Rule(FEATURE1), Rule(FEATURE2)]) + + rule_set.to_tsv(os.path.join(dir_name, "features.tsv")) + + rule_set = RuleSet() + + rule_set.from_tsv(os.path.join(dir_name, "features.tsv")) + + assert rule_set.to_list() == [FEATURE1, FEATURE2] + + +def test_ruleset_openie(): + rule_set = RuleSet([Rule(FEATURE3, openie=True), Rule(FEATURE4, openie=True)]) + rule_set.to_tsv(os.path.join(dir_name, "features_openie.tsv")) + + rule_set = RuleSet() + rule_set.from_tsv(os.path.join(dir_name, "features_openie.tsv")) + + assert rule_set.to_list() == [FEATURE3, FEATURE4] + + +def test_openie_matching(): + evaluator = FeatureEvaluator() + + G, _ = default_pn_to_graph(GRAPH) + + rule_set = RuleSet([Rule(FEATURE3, openie=True), Rule(FEATURE4, openie=True)]) + + triplets = list(evaluator.annotate(G, rule_set.to_list())) + + assert triplets == [ + {"relation": "FIND", "ARG1": "location", "ARG2": "object"}, + {"relation": "WRITE", "ARG1": "person", "ARG2": "sentence"}, + ] diff --git a/xpotato/dataset/dataset.py b/xpotato/dataset/dataset.py index 249f65f..fd1c103 100644 --- a/xpotato/dataset/dataset.py +++ b/xpotato/dataset/dataset.py @@ -1,24 +1,99 @@ -import pickle +from re import I from typing import List, Tuple, Dict import networkx as nx import pandas as pd +from tqdm import tqdm +from tuw_nlp.graph.utils import graph_to_pn from xpotato.dataset.sample import Sample from xpotato.graph_extractor.extract import GraphExtractor +from xpotato.graph_extractor.graph import PotatoGraph class Dataset: def __init__( - self, examples: List[Tuple[str, str]], label_vocab: Dict[str, int], lang="en" + self, + examples: List[Tuple[str, str]] = None, + label_vocab: Dict[str, int] = {}, + lang="en", + path=None, + binary=False, ) -> None: self.label_vocab = label_vocab - self._dataset = self.read_dataset(examples) + if path: + self._dataset = self.read_dataset(path=path, binary=binary) + else: + self._dataset = self.read_dataset(examples=examples) self.extractor = GraphExtractor(lang=lang, cache_fn=f"{lang}_nlp_cache") self.graphs = None - def read_dataset(self, examples: List[Tuple[str, str]]) -> List[Sample]: - return [Sample(example) for example in examples] + @staticmethod + def save_dataframe(df: pd.DataFrame, path: str) -> None: + graphs = [graph_to_pn(graph) for graph in df["graph"].tolist()] + df["graph"] = graphs + df.to_csv(path, index=False, sep="\t") + + def prune_graphs(self, graphs: 
diff --git a/xpotato/dataset/dataset.py b/xpotato/dataset/dataset.py
index 249f65f..fd1c103 100644
--- a/xpotato/dataset/dataset.py
+++ b/xpotato/dataset/dataset.py
@@ -1,24 +1,99 @@
-import pickle
 from typing import List, Tuple, Dict
 
 import networkx as nx
 import pandas as pd
+from tqdm import tqdm
+from tuw_nlp.graph.utils import graph_to_pn
 
 from xpotato.dataset.sample import Sample
 from xpotato.graph_extractor.extract import GraphExtractor
+from xpotato.graph_extractor.graph import PotatoGraph
 
 
 class Dataset:
     def __init__(
-        self, examples: List[Tuple[str, str]], label_vocab: Dict[str, int], lang="en"
+        self,
+        examples: List[Tuple[str, str]] = None,
+        label_vocab: Dict[str, int] = {},
+        lang="en",
+        path=None,
+        binary=False,
     ) -> None:
         self.label_vocab = label_vocab
-        self._dataset = self.read_dataset(examples)
+        if path:
+            self._dataset = self.read_dataset(path=path, binary=binary)
+        else:
+            self._dataset = self.read_dataset(examples=examples)
         self.extractor = GraphExtractor(lang=lang, cache_fn=f"{lang}_nlp_cache")
         self.graphs = None
 
-    def read_dataset(self, examples: List[Tuple[str, str]]) -> List[Sample]:
-        return [Sample(example) for example in examples]
+    @staticmethod
+    def save_dataframe(df: pd.DataFrame, path: str) -> None:
+        graphs = [graph_to_pn(graph) for graph in df["graph"].tolist()]
+        df["graph"] = graphs
+        df.to_csv(path, index=False, sep="\t")
+
+    def prune_graphs(self, graphs: List[nx.DiGraph] = None) -> List[str]:
+        graphs_str = []
+        for i, graph in enumerate(graphs):
+            graph.remove_nodes_from(list(nx.isolates(graph)))
+            # ADAM: THIS IS JUST FOR PICKLE TO PENMAN CONVERSION
+            graph = self._random_postprocess(graph)
+
+            g = [
+                c
+                for c in sorted(
+                    nx.weakly_connected_components(graph), key=len, reverse=True
+                )
+            ]
+            if len(g) > 1:
+                print(
+                    "WARNING: graph has multiple connected components, taking the largest"
+                )
+                g_pn = graph_to_pn(graph.subgraph(g[0].copy()))
+            else:
+                g_pn = graph_to_pn(graph)
+
+            graphs_str.append(g_pn)
+
+        return graphs_str
+
+    def read_dataset(
+        self,
+        examples: List[Tuple[str, str]] = None,
+        path: str = None,
+        binary: bool = False,
+    ) -> List[Sample]:
+        if examples:
+            return [Sample(example, PotatoGraph()) for example in examples]
+        elif path:
+            if binary:
+                df = pd.read_pickle(path)
+                graphs_str = self.prune_graphs(df.graph.tolist())
+                df.drop(columns=["graph"], inplace=True)
+                df["graph"] = graphs_str
+            else:
+                df = pd.read_csv(path, sep="\t")
+
+            return [
+                Sample(
+                    (example["text"], example["label"]),
+                    potato_graph=PotatoGraph(graph_str=example["graph"]),
+                    label_id=example["label_id"],
+                )
+                for _, example in tqdm(df.iterrows())
+            ]
+        else:
+            raise ValueError("No examples or path provided")
+
+    # ADAM: THIS WILL NEED TO BE ADDRESSED
+    def _random_postprocess(self, graph: nx.DiGraph) -> nx.DiGraph:
+        for node, attr in graph.nodes(data=True):
+            if len(attr["name"].split()) > 1:
+                attr["name"] = attr["name"].split()[0]
+
+        return graph
 
     def to_dataframe(self) -> pd.DataFrame:
         df = pd.DataFrame(
@@ -26,10 +101,9 @@
                 "text": [sample.text for sample in self._dataset],
                 "label": [sample.label for sample in self._dataset],
                 "label_id": [
-                    self.label_vocab[sample.label] if sample.label else None
-                    for sample in self._dataset
+                    sample.get_label_id(self.label_vocab) for sample in self._dataset
                 ],
-                "graph": [sample.graph for sample in self._dataset],
+                "graph": [sample.potato_graph.graph for sample in self._dataset],
             }
         )
         return df
@@ -41,23 +115,43 @@
     def parse_graphs(self, graph_format: str = "fourlang") -> List[nx.DiGraph]:
         graphs = list(
             self.extractor.parse_iterable(
                 [sample.text for sample in self._dataset], graph_format
             )
         )
 
-        self.graphs = graphs
-        return graphs
+        self.graphs = [PotatoGraph(graph) for graph in graphs]
+        return self.graphs
 
-    def set_graphs(self, graphs: List[nx.DiGraph]) -> None:
-        for sample, graph in zip(self._dataset, graphs):
-            graph.remove_edges_from(nx.selfloop_edges(graph))
-            sample.set_graph(graph)
+    def set_graphs(self, graphs: List[PotatoGraph]) -> None:
+        for sample, potato_graph in zip(self._dataset, graphs):
+            potato_graph.graph.remove_edges_from(nx.selfloop_edges(potato_graph.graph))
+            sample.set_graph(potato_graph)
 
-    def load_graphs(self, path: str) -> None:
-        PIK = path
+    def load_graphs(self, path: str, binary: bool = False) -> None:
+        if binary:
+            graphs = [graph for graph in pd.read_pickle(path)]
+            graph_str = self.prune_graphs(graphs)
 
-        with open(PIK, "rb") as f:
-            self.graphs = pickle.load(f)
+            graphs = [PotatoGraph(graph_str=graph) for graph in graph_str]
+            self.graphs = graphs
+        else:
+            self.graphs = []
+            with open(path, "r") as f:
+                for line in f:
+                    graph = PotatoGraph(graph_str=line.strip())
+                    self.graphs.append(graph)
 
         self.set_graphs(self.graphs)
 
+    def save_dataset(self, path: str) -> None:
+        df = pd.DataFrame(
+            {
+                "text": [sample.text for sample in self._dataset],
+                "label": [sample.label for sample in self._dataset],
+                "label_id": [
+                    sample.get_label_id(self.label_vocab) for sample in self._dataset
+                ],
+                "graph": [str(sample.potato_graph) for sample in self._dataset],
+            }
+        )
+        df.to_csv(path, index=False, sep="\t")
+
     def save_graphs(self, path: str) -> None:
-        PIK = path
-        with open(PIK, "wb") as f:
-            pickle.dump(self.graphs, f)
+        with open(path, "w") as f:
+            for graph in self.graphs:
+                f.write(str(graph) + "\n")
diff --git a/xpotato/dataset/sample.py b/xpotato/dataset/sample.py
index a62c171..8f0ed0e 100644
--- a/xpotato/dataset/sample.py
+++ b/xpotato/dataset/sample.py
@@ -1,13 +1,27 @@
-from typing import Tuple
+from typing import Dict, Tuple
 
-import networkx as nx
+from xpotato.graph_extractor.graph import PotatoGraph
 
 
 class Sample:
-    def __init__(self, example: Tuple[str, str]) -> None:
+    def __init__(
+        self,
+        example: Tuple[str, str],
+        potato_graph: PotatoGraph = None,
+        label_id: int = None,
+    ) -> None:
         self.text = example[0]
         self.label = example[1]
-        self.graph = None
+        self.label_id = label_id
+        self.potato_graph = potato_graph
 
-    def set_graph(self, graph: nx.DiGraph) -> None:
-        self.graph = graph
+    def set_graph(self, graph: PotatoGraph) -> None:
+        self.potato_graph = graph
+
+    def get_label_id(self, label_vocab: Dict[str, int]):
+        if self.label_id is not None:
+            return self.label_id
+        elif self.label and self.label in label_vocab:
+            return label_vocab[self.label]
+        else:
+            return None
diff --git a/xpotato/dataset/utils.py b/xpotato/dataset/utils.py
index f6c5b94..66fa97e 100644
--- a/xpotato/dataset/utils.py
+++ b/xpotato/dataset/utils.py
@@ -1,8 +1,17 @@
 from collections import defaultdict
 
 import networkx as nx
+import pandas as pd
 import penman as pn
 from tuw_nlp.graph.utils import preprocess_node_alto
+from tuw_nlp.graph.utils import graph_to_pn
+
+
+def save_dataframe(df: pd.DataFrame, path: str) -> None:
+    df_to_save = df.copy()
+    graphs = [graph_to_pn(graph) for graph in df_to_save["graph"].tolist()]
+    df_to_save["graph"] = graphs
+    df_to_save.to_csv(path, index=False, sep="\t")
 
 
 def ud_to_graph(sen, edge_attr="color"):
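With `save_dataframe` and the new `path`/`binary` constructor in place, the full TSV round trip looks roughly like this. A sketch, assuming the parsing backends (stanza models) are installed; the sentence is one of the notebook examples and the file name is hypothetical:

```python
from xpotato.dataset.dataset import Dataset
from xpotato.dataset.utils import save_dataframe

# Build a tiny dataset, parse graphs, and save everything to TSV.
dataset = Dataset(
    examples=[("The scientists poured XXX into pint YYY.", "Entity-Destination(e1,e2)")],
    label_vocab={"Other": 0, "Entity-Destination(e1,e2)": 1},
)
dataset.set_graphs(dataset.parse_graphs(graph_format="ud"))
save_dataframe(dataset.to_dataframe(), "example.tsv")

# Reload: the graph column holds Penman strings, rebuilt as PotatoGraph objects.
reloaded = Dataset(path="example.tsv")
print(reloaded.to_dataframe().iloc[0].graph)  # a networkx DiGraph
```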
diff --git a/xpotato/graph_extractor/extract.py b/xpotato/graph_extractor/extract.py
index e613f99..29a84f5 100644
--- a/xpotato/graph_extractor/extract.py
+++ b/xpotato/graph_extractor/extract.py
@@ -5,6 +5,7 @@
 import networkx as nx
 import pandas as pd
 import stanza
+import penman as pn
 from networkx.algorithms.isomorphism import DiGraphMatcher
 from sklearn.metrics import precision_recall_fscore_support
 from tqdm import tqdm
@@ -40,10 +41,12 @@
             nlp = stanza.Pipeline(self.lang)
             self.nlp = CachedStanzaPipeline(nlp, self.cache_fn)
 
-    def parse_iterable(self, iterable, graph_type="fourlang"):
+    def parse_iterable(self, iterable, graph_type="fourlang", lang=None):
+        if lang:
+            self.lang = lang
         if graph_type == "fourlang":
             with TextTo4lang(
                 lang=self.lang, nlp_cache=self.cache_fn, cache_dir=self.cache_dir
             ) as tfl:
                 for sen in tqdm(iterable):
                     fl_graphs = list(tfl(sen))
@@ -74,50 +77,121 @@ class FeatureEvaluator:
     def __init__(self, graph_format="ud"):
         self.graph_format = graph_format
 
-    def match_features(self, dataset, features, multi=False):
+    # ADAM: Very important to assign IDs to features from 0 because that's how
+    # the mapping will work!!
+    def annotate(self, graph, features):
+        feature_to_marked_nodes = {}
+
+        for i, feature in enumerate(features):
+            assert (
+                len(feature) == 4
+            ), f"Feature must be a 4-tuple for OpenIE, not {feature}"
+
+            positive_features = feature[0]
+            negative_features = feature[1]
+
+            for positive in positive_features:
+                p = pn.decode(positive)
+                first = p.triples[0][0]
+                assert first == "u_0", f"The IDs must start from 0, not {first}"
+
+            for negative in negative_features:
+                p = pn.decode(negative)
+                first = p.triples[0][0]
+                assert first == "u_0", f"The IDs must start from 0, not {first}"
+
+            feature_to_marked_nodes[i] = feature[3]
+            features[i] = feature[:3]
+
+        matcher = GraphFormulaMatcher(features, converter=default_pn_to_graph)
+        feats = matcher.match(graph, return_subgraphs=True)
+
+        for key, i, subgraphs in feats:
+            triplet = {"relation": key}
+            marked_nodes = feature_to_marked_nodes[i]
+            for j, node in enumerate(marked_nodes):
+                subgraph = subgraphs[j]
+
+                node_to_node = {}
+                for id, graph_node in subgraph.nodes(data=True):
+                    mapping = graph_node["mapping"]
+                    node_to_node[mapping] = graph_node["name"]
+
+                for k, v in node.items():
+                    triplet[k] = node_to_node[v]
+
+            yield triplet
+
+    def annotate_dataframe(self, dataset, features):
+        graphs = dataset.graph.tolist()
+
+        triplets = []
+        for graph in graphs:
+            relations = self.annotate(graph, features)
+            triplets.append(list(relations))
+        d = {
+            "Sentence": dataset.text.tolist(),
+            "Triplets": triplets,
+        }
+
+        return pd.DataFrame(d)
+
+    def match_features(self, dataset, features, multi=False, return_subgraphs=False):
         graphs = dataset.graph.tolist()
 
         matches = []
         predicted = []
+        matched_graphs = []
 
         matcher = GraphFormulaMatcher(features, converter=default_pn_to_graph)
 
         for i, g in tqdm(enumerate(graphs)):
-            feats = matcher.match(g)
+            feats = matcher.match(g, return_subgraphs=return_subgraphs)
             if multi:
-                self.match_multi(feats, features, matches, predicted)
+                self.match_multi(feats, features, matches, predicted, matched_graphs)
             else:
-                self.match_not_multi(feats, features, matches, predicted)
+                self.match_not_multi(
+                    feats, features, matches, predicted, matched_graphs
+                )
 
         d = {
             "Sentence": dataset.text.tolist(),
             "Predicted label": predicted,
             "Matched rule": matches,
         }
+        if return_subgraphs:
+            d["Matched subgraph"] = matched_graphs
+
         df = pd.DataFrame(d)
         return df
 
-    def match_multi(self, feats, features, matches, predicted):
+    def match_multi(self, feats, features, matches, predicted, matched_graphs):
         keys = []
         matched_rules = []
-        for key, feature in feats:
+        matched_subgraphs = []
+        for key, feature, graphs in feats:
             if key not in keys:
                 matched_rules.append(features[feature])
+                matched_subgraphs.append(graphs)
                 keys.append(key)
         if not keys:
             matches.append("")
             predicted.append("")
+            matched_graphs.append("")
         else:
             matches.append(matched_rules)
             predicted.append(keys)
+            matched_graphs.append(matched_subgraphs)
 
-    def match_not_multi(self, feats, features, matches, predicted):
-        for key, feature in feats:
+    def match_not_multi(self, feats, features, matches, predicted, matched_graphs):
+        for key, feature, graphs in feats:
             matches.append(features[feature])
             predicted.append(key)
+            matched_graphs.append(graphs)
             break
         else:
             matches.append("")
+            matched_graphs.append("")
             predicted.append("")
 [whitespace-only reformatting hunks in select_words and evaluate_feature omitted]
diff --git a/xpotato/graph_extractor/graph.py b/xpotato/graph_extractor/graph.py
new file mode 100644
index 0000000..3613f0d
--- /dev/null
+++ b/xpotato/graph_extractor/graph.py
@@ -0,0 +1,16 @@
+import networkx as nx
+from tuw_nlp.graph.utils import graph_to_pn
+from xpotato.dataset.utils import default_pn_to_graph
+
+
+class PotatoGraph:
+    def __init__(self, graph: nx.DiGraph = None, graph_str: str = None) -> None:
+        if graph:
+            self.graph = graph
+        elif graph_str:
+            self.graph, _ = default_pn_to_graph(graph_str)
+        else:
+            self.graph = None
+
+    def __str__(self) -> str:
+        return graph_to_pn(self.graph)
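Since `PotatoGraph` is the pivot type of the new serialization, a short illustration may help (the graph string is borrowed from the notebook's extracted rule; the node count is what `default_pn_to_graph` should produce for three concepts):

```python
from xpotato.graph_extractor.graph import PotatoGraph

# A Penman string is parsed into a networkx DiGraph on construction...
pg = PotatoGraph(graph_str="(u_1 / into :1 (u_2 / pour) :2 (u_3 / YYY))")
print(len(pg.graph.nodes))  # 3 nodes: into, pour, YYY

# ...and str() serializes it back through graph_to_pn.
print(str(pg))
```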
diff --git a/xpotato/graph_extractor/rule.py b/xpotato/graph_extractor/rule.py
index 2c13c1b..5a81b47 100644
--- a/xpotato/graph_extractor/rule.py
+++ b/xpotato/graph_extractor/rule.py
@@ -1,5 +1,126 @@
+import json
+from typing import Dict, List, Union
+
+
 class Rule:
-    def __init__(self, rule):
+    def __init__(self, rule, openie=False):
         self.positive_samples = rule[0]
         self.negative_samples = rule[1]
         self.label = rule[2]
+        self.openie = openie
+
+        self.marked_nodes = None
+        if openie:
+            self.marked_nodes = rule[3]
+
+    def __eq__(self, __o: object) -> bool:
+        if not isinstance(__o, Rule):
+            return False
+        return (
+            sorted(self.positive_samples) == sorted(__o.positive_samples)
+            and sorted(self.negative_samples) == sorted(__o.negative_samples)
+            and self.label == __o.label
+        )
+
+    def to_list(self) -> List[List[Union[List[str], str]]]:
+        return (
+            [self.positive_samples, self.negative_samples, self.label]
+            if not self.openie
+            else [
+                self.positive_samples,
+                self.negative_samples,
+                self.label,
+                self.marked_nodes,
+            ]
+        )
+
+
+class RuleSet:
+    def __init__(self, rules: List[Rule] = None):
+        if rules is None:
+            self.rules = []
+        else:
+            self.rules = rules
+
+    def __iter__(self):
+        for rule in self.rules:
+            yield rule
+
+    def __eq__(self, __o: object) -> bool:
+        if not isinstance(__o, RuleSet):
+            return False
+        return self.rules == __o.rules
+
+    def to_tsv(self, tsv_path: str):
+        with open(tsv_path, "w") as f:
+            for rule in self.rules:
+                positive_samples = ";".join(rule.positive_samples)
+                negative_samples = ";".join(rule.negative_samples)
+                label = rule.label
+                marked_nodes = rule.marked_nodes
+
+                rule_str = f"{label}\t{positive_samples}\t{negative_samples}"
+
+                if rule.openie:
+                    rule_str += f"\t{json.dumps(marked_nodes)}"
+                f.write(rule_str + "\n")
+
+    def from_tsv(self, tsv_path: str):
+        with open(tsv_path, "r") as f:
+            for line in f:
+                line = line.strip("\n")
+                line = line.split("\t")
+
+                positive_samples = [] if line[1] == "" else line[1].split(";")
+                negative_samples = [] if line[2] == "" else line[2].split(";")
+                label = line[0].strip()
+                rule = None
+                if len(line) == 3:
+                    rule = Rule(
+                        [positive_samples, negative_samples, label], openie=False
+                    )
+                elif len(line) == 4:
+                    marked_nodes = [] if line[3] == "" else json.loads(line[3])
+                    rule = Rule(
+                        [positive_samples, negative_samples, label, marked_nodes],
+                        openie=True,
+                    )
+                else:
+                    raise Exception(f"Invalid number of fields: {line}")
+
+                if not rule:
+                    raise Exception(f"Invalid rule: {line}")
+                self.add_rule(rule)
+
+    def from_dict(
+        self, rules: Dict[str, List[List[Union[List[str], str]]]], openie: bool = False
+    ):
+        for key, value in rules.items():
+            for rule in value:
+                self.add_rule(Rule(rule, openie=openie))
+
+    def from_json(self, json_path: str, openie: bool = False):
+        with open(json_path, "r") as f:
+            rules = json.load(f)
+
+            for key, value in rules.items():
+                for rule in value:
+                    self.add_rule(Rule(rule, openie=openie))
+
+    def to_json(self, json_path: str):
+        with open(json_path, "w") as f:
+            json.dump(self.to_dict(), f)
+
+    def add_rule(self, rule: Rule):
+        self.rules.append(rule)
+
+    def to_dict(self) -> Dict[str, List[List[Union[List[str], str]]]]:
+        rule_dict = {rule.label: [] for rule in self.rules}
+
+        for rule in self.rules:
+            rule_dict[rule.label].append(rule.to_list())
+
+        return rule_dict
+
+    def to_list(self) -> List[List[Union[List[str], str]]]:
+        return [rule.to_list() for rule in self.rules]
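Finally, a sketch of the plain (non-OpenIE) `RuleSet` round trip through the new TSV format, mirroring `test_ruleset_to_tsv` above (the file name is hypothetical):

```python
from xpotato.graph_extractor.rule import Rule, RuleSet

rule_set = RuleSet(
    [Rule([["(u_0 / find :obl (u_1 / .*) :obj (u_2 / .*))"], [], "FIND"])]
)

# One rule per line: label, ;-joined positives, ;-joined negatives.
rule_set.to_tsv("features.tsv")

loaded = RuleSet()
loaded.from_tsv("features.tsv")
assert loaded.to_list() == rule_set.to_list()

# The dict form groups rules by label, matching the features.json layout.
print(rule_set.to_dict())
```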