Skip to content

Commit

Permalink
Merge pull request #30 from Jaseci-Labs/semantic_sim_updates
Browse files Browse the repository at this point in the history
Updating functions for semantic similarity
  • Loading branch information
chandralegend authored May 6, 2024
2 parents d4f59c6 + 4184d49 commit 427e2fa
Show file tree
Hide file tree
Showing 13 changed files with 312 additions and 281 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/app_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r app/requirements.txt
- name: Run tests
run: sh scripts/run_tests.sh
run: |
cd app
jac test -f "test_*.jac"
3 changes: 2 additions & 1 deletion app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ matplotlib

# Auto Evaluator
sentence_transformers
nltk
nltk
tensorflow_hub
79 changes: 58 additions & 21 deletions app/src/components/auto_evaluator/emb_sim_scorer.impl.jac
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import:py from nltk.translate.bleu_score, sentence_bleu;
import:py from nltk.translate.bleu_score, SmoothingFunction;
import:py from torch, tensor;
import:py from nltk, ngrams;
import:py tensorflow as tf;
import:py from collections, Counter;

:can:generate_embeddings
(anchor_responses_text: list, response_texts: list, embedder: str) {

:can:generate_embeddings(anchor_responses_text: list, response_texts: list, embedder: str) {
anchor_embeddings = [];
response_embeddings = [];
if embedder == "SBERT" {
Expand All @@ -33,14 +35,27 @@ import:py from nltk, ngrams;
} elif embedder == "USE_QA" {
import:py tensorflow_hub as hub;
model = hub.load("https://tfhub.dev/google/universal-sentence-encoder-qa/3");
anchor_embeddings = model.signatures['question_encoder'](tf.constant([anchor_responses_text]))['outputs'];
if not isinstance(anchor_responses_text, list){anchor_responses_text = [anchor_responses_text];}
for i in range(len(anchor_responses_text)){
if not isinstance(anchor_responses_text[i], str){
anchor_responses_text[i] = str(anchor_responses_text[i]);
}
}

if not isinstance(response_texts, list){response_texts = [response_texts];}
for i in range(len(response_texts)){
if not isinstance(response_texts[i], str){
response_texts[i] = str(response_texts[i]);
}
}
anchor_embeddings = model.signatures['question_encoder'](input=tf.constant(anchor_responses_text))['outputs'];
response_embeddings = model.signatures['response_encoder'](input=tf.constant(response_texts), context=tf.constant(response_texts))['outputs'];
}
return (anchor_embeddings, response_embeddings);
}

:can:calculate_similarity_score
(anchor_embeddings: list, response_embeddings: list, scorer: str) {

:can:calculate_similarity_score(anchor_embeddings: list, response_embeddings: list, scorer: str) {
anchor_embeddings = np.array(anchor_embeddings);
response_embeddings = np.array(response_embeddings);
scores = [];
Expand All @@ -57,13 +72,11 @@ import:py from nltk, ngrams;
}
}

:can:display_results
(basedir: str, heatmap_placeholder: st, selected_prompt: str=None) {
:can:display_results(basedir: str, heatmap_placeholder: st, selected_prompt: str=None) {
heat_map(basedir, "A/B Testing", heatmap_placeholder, selected_prompt);
}

:can:process_user_selections
(selected_prompt: str=None) {
:can:process_user_selections (selected_prompt: str=None) {
with open(st.session_state.distribution_file, "r") as fp {
distribution = json.load(fp);
}
Expand Down Expand Up @@ -106,8 +119,7 @@ import:py from nltk, ngrams;
}
}

:can:calculate_embedding_score
(responses: list, anchor_reponses_id: dict, responses_dict: dict) -> None {
:can:calculate_embedding_score(responses: list, anchor_reponses_id: dict, responses_dict: dict) -> None {
anchor_reponses_text = [responses_dict[resp_id] for resp_id in anchor_reponses_id];
response_texts = [responses_dict[resp_id] for resp_id in responses.values()];
if not st.session_state['scorer'] == "sem_bleu" {
Expand All @@ -122,22 +134,47 @@ import:py from nltk, ngrams;
return best_response_idx;
}

:can:embed_sentence
(sentence: str, model: SentenceTransformer) {
:can:embed_sentence(sentence: str, model: SentenceTransformer) {
return model.encode(sentence, convert_to_tensor=True);
}

:can:compute_bleu_score
(reference: str, candidate: str) {
:can:simple_bleu(reference: str, candidate: str, n_gram: int=4) {
reference_tokens = word_tokenize(reference);
candidate_tokens = word_tokenize(candidate);
smoothie = SmoothingFunction().method4;
bleu_score = sentence_bleu(reference_tokens, candidate_tokens, smoothing_function=smoothie);
return bleu_score;
reference_ngrams = [ngrams(reference_tokens, i) for i in range(1, n_gram+1)];
candidate_ngrams = [ngrams(candidate_tokens, i) for i in range(1, n_gram+1)];

weights = np.ones(n_gram) / n_gram;
p_ns = [];

n = min(len(reference_ngrams), len(candidate_ngrams));
i = 0;
while (i < n) {
ref_ng = list(reference_ngrams[i]); # Convert generator to list if necessary
cand_ng = list(candidate_ngrams[i]); # Convert generator to list if necessary
ref_count = Counter(ref_ng);
cand_count = Counter(cand_ng);

count = sum((cand_count & ref_count).values());
total = sum(cand_count.values());

p_n = count / total if total > 0 else 0;
p_ns.append(p_n);
i = i + 1;
}

weights = np.array(weights);
p_ns = np.array(p_ns);
p_ns = np.log(p_ns, out=np.zeros_like(p_ns), where=(p_ns != 0));
bleu = np.exp(np.sum(p_ns * weights));
return bleu;
}

:can:compute_bleu_score(reference: str, candidate: str) {
return simple_bleu(reference, candidate);
}

:can:semantic_bleu_score
(anchor_responses_text: list, response_texts: list, model: SentenceTransformer, ngram_size: int=4, scaling_factor: float=1, bleu_weight: float=0.5) {

:can:semantic_bleu_score(anchor_responses_text: list, response_texts: list, model: SentenceTransformer, ngram_size: int=4, scaling_factor: float=1, bleu_weight: float=0.5) {
scores = [];
for candidate in response_texts {
anchor_score = [];
Expand Down
73 changes: 32 additions & 41 deletions app/src/components/auto_evaluator/emb_sim_scorer.jac
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,55 @@ import:py streamlit as st;
import:py from sentence_transformers, SentenceTransformer;

can generate_embeddings(anchor_responses_text: str, response_texts: list, embedder: str);

can calculate_similarity_score(anchor_embeddings: list, response_embeddings: list, scorer: str);

can display_results(basedir: str, heatmap_placeholder: st, selected_prompt: str=None);

can process_user_selections(selected_prompt: str=None);

can calculate_embedding_score(responses: list, anchor_reponses_id: dict, responses_dict: dict);

can embed_sentence(sentence: str, model: SentenceTransformer);

can compute_bleu_score(reference: list, candidate: list);

can semantic_bleu_score(anchor_responses_text: list, response_texts: list, model: SentenceTransformer, ngram_size: int=4, scaling_factor: float=1, bleu_weight: float=0.5);

glob ANCHOR_MODEL_KEY = 'anchor_model';

glob EMBEDDER_KEY = 'embedder';

glob SCORER_KEY = 'scorer';
can simple_bleu(reference: str, candidate: str, n_gram: int=4);

can emb_sim_scorer {
if ANCHOR_MODEL_KEY not in st.session_state {
st.session_state[ANCHOR_MODEL_KEY] = 'gpt-4';
if 'anchor_model' not in st.session_state {
st.session_state['anchor_model'] = 'gpt-4';
}
if EMBEDDER_KEY not in st.session_state {
st.session_state[EMBEDDER_KEY] = 'SBERT';
if 'embedder' not in st.session_state {
st.session_state['embedder'] = 'SBERT';
}
if SCORER_KEY not in st.session_state {
st.session_state[SCORER_KEY] = 'cos_sim';
if 'scorer' not in st.session_state {
st.session_state['scorer'] = 'cos_sim';
}
if st.session_state.get("current_hv_config", None) {
if 'button_clicked' not in st.session_state {
st.session_state.button_clicked = False;
}
if st.session_state.button_clicked {
if "selected_prompt" in st.session_state {
process_user_selections(st.session_state["selected_prompt"]);
}
st.session_state.button_clicked = False;
}
button_clicked = st.session_state.get('button_clicked', False);
model_list = st.session_state.active_list_of_models;
if st.session_state[ANCHOR_MODEL_KEY] not in model_list {
st.session_state[ANCHOR_MODEL_KEY] = model_list[0];
}
(col1, col2, col3) = st.columns(3);
with col1 {
anchor_model_selection = st.selectbox("Select Anchor Model", options=model_list, key=ANCHOR_MODEL_KEY, index=model_list.index(st.session_state[ANCHOR_MODEL_KEY]));
}
with col2 {
embedder_selection = st.selectbox("Select Type of Embedder", options=['USE', 'USE_QA', 'SBERT'], key=EMBEDDER_KEY, index=['USE', 'USE_QA', 'SBERT', 'OPEN_AI_Embedder'].index(st.session_state[EMBEDDER_KEY]));
if st.session_state['anchor_model'] not in model_list {
st.session_state['anchor_model'] = model_list[0];
}
with col3 {
scorer_selection = st.selectbox("Select Scorer", options=['cos_sim', 'sem_bleu'], key=SCORER_KEY, index=['cos_sim', 'sem_bleu'].index(st.session_state[SCORER_KEY]));

if st.session_state['anchor_model'] not in model_list {
st.session_state['anchor_model'] = model_list[0];
}

(col1, col2, col3) = st.columns(3);
anchor_model_selection = col1.selectbox("Select Anchor Model", options=model_list, key='anchor_model', index=model_list.index(st.session_state.get('anchor_model', model_list[0])));
embedder_selection = col2.selectbox("Select Type of Embedder", options=['USE', 'USE_QA', 'SBERT'], key='embedder', index=['USE', 'USE_QA', 'SBERT', 'OPEN_AI_Embedder'].index(st.session_state.get('embedder', 'SBERT')));
scorer_selection = col3.selectbox("Select Scorer", options=['cos_sim', 'sem_bleu'], key='scorer', index=['cos_sim', 'sem_bleu'].index(st.session_state.get('scorer', 'cos_sim')));

if st.button('Calculate Embedding Scores') {
st.session_state.button_clicked = True;
process_user_selections();
try {
with st.spinner('Calculating embedding scores... Please wait.'){
process_user_selections();
st.session_state['button_clicked'] = True;
}
st.success('Finished calculating embedding scores!');
} except Exception as e{
print(e);
st.error('Error calculating embedding scores. Please try again.');
}
}
if button_clicked {
st.session_state['button_clicked'] = False;
}
} else {
st.error("Human Evaluation config was not found. Initialize a Human Evaluation first.");
Expand Down
2 changes: 1 addition & 1 deletion app/src/components/dashboard/dashboard.impl.jac
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import:jac from plot_utils, generate_stacked_bar_chart, generate_heatmaps;
st.session_state.workers_data_dir = os.path.abspath("results");
st.session_state.distribution_file = os.path.abspath(os.path.join(".human_eval_config", "distribution.json"));
st.session_state.response_file = os.path.abspath(os.path.join(".human_eval_config", "responses.json"));
st.session_state.prompt_data_dir = os.path.abspath("data");
st.session_state.prompt_data_dir = os.path.abspath("data"); #TODO: Uses to get the run name, Fix is to include that in the prompt info file
st.session_state.prompt_info_file = os.path.abspath(os.path.join(".human_eval_config", "prompt_info.json"));
st.session_state.models_responses = os.path.abspath(os.path.join(".human_eval_config", "models_responses.json"));
with open(st.session_state.models_responses, "r") as f {
Expand Down
6 changes: 3 additions & 3 deletions app/src/components/setup/setup.impl.jac
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ can add_data_sources {
st.subheader("Human Evaluation Configuration");
(hv_config_1_col, hv_config_2_col, hv_config_3_col) = st.columns(3);
with hv_config_1_col {
n_workers = st.number_input("Number of workers", min_value=10, step=1, value=st.session_state.config["config"]["n_workers"], help="Number of Evaluators going to participate");
n_questions_per_worker = st.number_input("Number of questions per worker", min_value=2, max_value=100, step=1, value=st.session_state.config["config"]["n_questions_per_worker"], help="Number of questions shown to an Evaluator");
n_workers = st.number_input("Number of Evaluators", min_value=10, step=1, value=st.session_state.config["config"]["n_workers"], help="Number of Evaluators going to participate");
n_questions_per_worker = st.number_input("Number of questions per evaluator", min_value=2, max_value=100, step=1, value=st.session_state.config["config"]["n_questions_per_worker"], help="Number of questions shown to an Evaluator");
show_captcha = st.checkbox("Show Captcha (Human Verification)", value=st.session_state.config["config"]["show_captcha"]);
ability_to_tie = st.selectbox("Ability to Choose Both", ["Allow", "Not Allowed"], index=["Allow", "Not Allowed"].index(st.session_state.config["config"]["ability_to_tie"]), help="Select whether the evaluator can choose both options as the same.");
evenly_distributed = st.checkbox("Usecases are Evenly distributed among the workers", value=st.session_state.config["config"]["evenly_distributed"], help="If checked, the usecases will be evenly distributed among the workers. for example, if there are 2 usecases and 10 workers, each worker will get 1 question from each usecase. If not checked, the questions will be randomly distributed.");
evenly_distributed = st.checkbox("Usecases are Evenly distributed among the evaluators", value=st.session_state.config["config"]["evenly_distributed"], help="If checked, the usecases will be evenly distributed among the workers. for example, if there are 2 usecases and 10 workers, each worker will get 1 question from each usecase. If not checked, the questions will be randomly distributed.");
}
with hv_config_2_col {
json_files = [f for f in os.listdir("data") if f.endswith(".json")] if os.path.exists("data") else [];
Expand Down
Loading

0 comments on commit 427e2fa

Please sign in to comment.