-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
64 lines (53 loc) · 3.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import logging
import os

import numpy as np

from cross_align.data_loader import load_fasttext_model, get_top_n_words, load_bilingual_lexicon
from cross_align.alignement import align_embeddings, apply_alignment, iterative_procrustes_alignment
from cross_align.evaluation import word_translation_accuracy, analyze_cosine_similarities, ablation_study, plot_ablation_results, plot_similarity_distribution
def setup_logging():
    """Configure root logging to write to logs/alignment.log and stderr.

    Creates the ``logs/`` directory if it is missing — otherwise
    ``logging.FileHandler`` raises ``FileNotFoundError`` on first run.
    """
    os.makedirs("logs", exist_ok=True)
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        handlers=[
                            logging.FileHandler("logs/alignment.log"),
                            logging.StreamHandler()
                        ])
def main(args):
    """Run the full cross-lingual (en->hi) embedding alignment pipeline.

    Steps: load FastText models, extract vocabularies, load the MUSE
    bilingual lexicon, compute a supervised Procrustes-style alignment,
    evaluate translation precision, analyze cosine similarities, and run
    an ablation study over lexicon sizes.

    Args:
        args: ``argparse.Namespace`` with a boolean ``trained`` flag that
            selects locally-trained vs. pre-trained FastText models.
    """
    setup_logging()
    logging.info("Starting cross-lingual alignment pipeline...")

    embedding_dir = "./embeddings/"
    muse_dir = "lexicon/"
    model_type = "trained" if args.trained else "pre-trained"

    # Lazy %-style args keep formatting out of the hot path (logging idiom).
    logging.info("Loading %s FastText models...", model_type)
    en_embeddings = load_fasttext_model(embedding_dir, 'en', trained=args.trained)
    hi_embeddings = load_fasttext_model(embedding_dir, 'hi', trained=args.trained)

    logging.info("Extracting top 100000 words from FastText models...")
    en_words = get_top_n_words(en_embeddings)
    hi_words = get_top_n_words(hi_embeddings)

    logging.info("Loading bilingual lexicon...")
    train_dict = load_bilingual_lexicon(muse_dir, 'en', 'hi')
    test_dict = load_bilingual_lexicon(muse_dir, 'en', 'hi', train=False)

    logging.info("Performing supervised alignment...")
    alignment_matrix = align_embeddings(en_embeddings, hi_embeddings, en_words, hi_words, train_dict)
    logging.info("Initial alignment matrix computed.")
    en_aligned_supervised = apply_alignment(en_embeddings, alignment_matrix)

    logging.info("Evaluating supervised alignment...")
    p1, p5 = word_translation_accuracy(en_aligned_supervised, hi_embeddings, en_words, hi_words, test_dict)
    logging.info("Supervised Alignment Results for %s model: Precision@1: %.4f, Precision@5: %.4f",
                 model_type, p1, p5)

    logging.info("Analyzing cosine similarities...")
    word_pairs = load_bilingual_lexicon(muse_dir, 'en', 'hi', max_pairs=1000, train=False)
    similarities = analyze_cosine_similarities(en_aligned_supervised, hi_embeddings, en_words, hi_words, word_pairs)

    # Derive the log message from `sizes` so the list and the message
    # cannot drift apart (previously the sizes were hard-coded twice).
    sizes = [5000, 10000, 20000]
    logging.info("Starting ablation study with sizes: %s", ", ".join(map(str, sizes)))
    ablation_results = ablation_study(en_embeddings, hi_embeddings, en_words, hi_words, train_dict, test_dict, sizes)
    # Distinct loop names: the original reused p1/p5, clobbering the
    # supervised-evaluation results above.
    for size, size_p1, size_p5 in ablation_results:
        logging.info("For sizes %s: P@1 = %.4f, P@5 = %.4f", size, size_p1, size_p5)

    logging.info("Plotting ablation study results...")
    plot_ablation_results(ablation_results, model_type)
    plot_similarity_distribution(similarities, model_type)
    logging.info("Ablation study completed and plotted.")
if __name__ == "__main__":
    # CLI entry point: a single flag chooses trained vs. pre-trained models.
    cli = argparse.ArgumentParser(description="Cross-lingual embedding alignment pipeline.")
    cli.add_argument(
        "--trained",
        action="store_true",
        help="Use trained FastText models instead of pre-trained models.",
    )
    main(cli.parse_args())