Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tanya-borisova committed Mar 15, 2024
1 parent f5dabff commit 7a52865
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 93 deletions.
126 changes: 39 additions & 87 deletions rag_experiment_accelerator/ingest_data/tests/test_acs_ingest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import json
import uuid

from unittest.mock import patch, Mock
from unittest.mock import patch, Mock, ANY

from rag_experiment_accelerator.ingest_data.acs_ingest import (
generate_title,
my_hash,
generate_summary,
upload_data,
generate_qna,
we_need_multiple_questions,
do_we_need_multiple_questions,
)

from rag_experiment_accelerator.llm.prompts import (
prompt_instruction_title,
prompt_instruction_summary,
multiple_prompt_instruction,
)

Expand Down Expand Up @@ -62,95 +58,48 @@ def test_my_hash_with_numbers():
assert result == expected_hash


@patch("rag_experiment_accelerator.ingest_data.acs_ingest.ResponseGenerator")
def test_generate_title(mock_response_generator):
    """generate_title returns whatever the response generator produces.

    ResponseGenerator is patched in acs_ingest, so the test only checks that
    the title prompt and the chunk are forwarded to generate_response and that
    its return value is passed back unchanged.
    """
    # Arrange: the instance the patched class hands out replies with a title.
    expected_title = "Test Title"
    chunk_text = "This is a test chunk of text."
    deployment = "TestDeployment"
    fake_generator = mock_response_generator.return_value
    fake_generator.generate_response.return_value = expected_title

    # Act
    title = generate_title(chunk_text, deployment)

    # Assert: exactly one request, made with the title prompt and the chunk.
    fake_generator.generate_response.assert_called_once_with(
        prompt_instruction_title, chunk_text
    )
    assert title == expected_title


@patch("rag_experiment_accelerator.ingest_data.acs_ingest.ResponseGenerator")
def test_generate_summary(mock_response_generator):
    """generate_summary returns whatever the response generator produces.

    ResponseGenerator is patched in acs_ingest, so the test only checks that
    the summary prompt and the chunk are forwarded to generate_response and
    that its return value is passed back unchanged.
    """
    # Arrange: the instance the patched class hands out replies with a summary.
    expected_summary = "Test Summary"
    chunk_text = "This is a test chunk of text."
    deployment = "TestDeployment"
    fake_generator = mock_response_generator.return_value
    fake_generator.generate_response.return_value = expected_summary

    # Act
    summary = generate_summary(chunk_text, deployment)

    # Assert: exactly one request, made with the summary prompt and the chunk.
    fake_generator.generate_response.assert_called_once_with(
        prompt_instruction_summary, chunk_text
    )
    assert summary == expected_summary


@patch("rag_experiment_accelerator.ingest_data.acs_ingest.SearchClient")
@patch("rag_experiment_accelerator.ingest_data.acs_ingest.AzureKeyCredential")
@patch("rag_experiment_accelerator.ingest_data.acs_ingest.generate_title")
@patch("rag_experiment_accelerator.ingest_data.acs_ingest.generate_summary")
@patch("rag_experiment_accelerator.ingest_data.acs_ingest.my_hash")
@patch("rag_experiment_accelerator.ingest_data.acs_ingest.pre_process.preprocess")
@patch("rag_experiment_accelerator.ingest_data.acs_ingest.ResponseGenerator")
def test_upload_data(
mock_response_generator,
mock_preprocess,
mock_my_hash,
mock_generate_summary,
mock_generate_title,
mock_AzureKeyCredential,
mock_azure_key_credential,
mock_SearchClient,
):
# Arrange
mock_chunks = [{"content": "test content", "content_vector": "test_vector"}]
mock_service_endpoint = "test_endpoint"
mock_index_name = "test_index"
mock_search_key = "test_key"
mock_embedding_model = Mock()
mock_azure_oai_deployment_name = "test_deployment"
mock_my_hash.return_value = "test_hash"
mock_generate_title.return_value = "test_title"
mock_generate_summary.return_value = "test_summary"
mock_preprocess.return_value = "test_preprocessed_content"
mock_AzureKeyCredential.return_value = "test_credential"
mock_environment = Mock()
mock_environment.azure_search_service_endpoint = "test_endpoint"
mock_environment.azure_search_admin_key = "test_key"
mock_config = Mock()
mock_response_generator.return_value.generate_response.return_value = "test_text"

# Act
upload_data(
mock_environment,
mock_config,
mock_chunks,
mock_service_endpoint,
mock_index_name,
mock_search_key,
"test_index",
mock_embedding_model,
mock_azure_oai_deployment_name,
)

# Assert
mock_AzureKeyCredential.assert_called_once_with(mock_search_key)
mock_azure_key_credential.assert_called_once_with(mock_search_key)
mock_SearchClient.assert_called_once_with(
endpoint=mock_service_endpoint,
index_name=mock_index_name,
credential="test_credential",
endpoint="test_endpoint",
index_name="test_index",
credential=ANY,
)
mock_my_hash.assert_called_once_with(mock_chunks[0]["content"])
mock_generate_title.assert_called_once_with(
str(mock_chunks[0]["content"]), mock_azure_oai_deployment_name
)
mock_generate_summary.assert_called_once_with(
str(mock_chunks[0]["content"]), mock_azure_oai_deployment_name
)
mock_preprocess.assert_any_call("test_summary")
mock_preprocess.assert_any_call("test_title")
mock_preprocess.assert_any_call("test_text")
mock_embedding_model.generate_embedding.assert_any_call(
chunk="test_preprocessed_content"
)
Expand All @@ -177,62 +126,65 @@ def test_generate_qna_with_invalid_json(mock_response_generator, mock_json_loads
mock_json_loads.side_effect = json.JSONDecodeError("Invalid JSON", doc="", pos=0)

# Act
result = generate_qna(mock_docs, mock_deployment_name)
result = generate_qna(Mock(), Mock(), mock_docs, mock_deployment_name)

# Assert
assert len(result) == 0
mock_json_loads.assert_called_once_with(mock_response)


@patch(
    "rag_experiment_accelerator.ingest_data.acs_ingest.ResponseGenerator",
    return_value=Mock(),
)
def test_we_need_multiple_questions(mock_response_generator):
    """we_need_multiple_questions sends the built prompt and returns the reply.

    The patched ResponseGenerator mock is itself passed in as the response
    generator argument, so the expectations are set on (and checked against)
    ``mock_response_generator.generate_response`` directly rather than on an
    instance created from it.
    """
    # Arrange
    question = "What is the meaning of life?"
    mock_response = "The meaning of life is 42."
    mock_response_generator.generate_response.return_value = mock_response
    expected_prompt_instruction = (
        multiple_prompt_instruction + "\n" + "question: " + question + "\n"
    )

    # Act
    result = we_need_multiple_questions(question, mock_response_generator)

    # Assert: a single request with the composed instruction and empty context.
    mock_response_generator.generate_response.assert_called_once_with(
        expected_prompt_instruction, ""
    )
    assert result == mock_response


@patch(
    "rag_experiment_accelerator.ingest_data.acs_ingest.ResponseGenerator",
    return_value=Mock(),
)
def test_do_we_need_multiple_questions_true(mock_response_generator):
    """A reply with category "complex" makes do_we_need_multiple_questions True.

    The patched ResponseGenerator mock is passed in as the response generator
    argument, so its ``generate_response`` is configured and checked directly.
    """
    # Arrange
    question = "What is the meaning of life?"
    mock_response_generator.generate_response.return_value = '{"category": "complex"}'

    # Act
    result = do_we_need_multiple_questions(question, mock_response_generator)

    # Assert: the generator is consulted exactly once.
    mock_response_generator.generate_response.assert_called_once()
    assert result is True


@patch(
    "rag_experiment_accelerator.ingest_data.acs_ingest.ResponseGenerator",
    return_value=Mock(),
)
def test_do_we_need_multiple_questions_false(mock_response_generator):
    """A reply with an empty category makes do_we_need_multiple_questions False.

    The patched ResponseGenerator mock is passed in as the response generator
    argument, so its ``generate_response`` is configured and checked directly.
    """
    # Arrange
    question = "What is the meaning of life?"
    mock_response_generator.generate_response.return_value = '{"category": ""}'

    # Act
    result = do_we_need_multiple_questions(question, mock_response_generator)

    # Assert: the generator is consulted exactly once.
    mock_response_generator.generate_response.assert_called_once()
    assert result is False
13 changes: 7 additions & 6 deletions rag_experiment_accelerator/nlp/language_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from azure.ai.textanalytics import TextAnalyticsClient
from azure.core.credentials import AzureKeyCredential

from rag_experiment_accelerator.config import AzureSkillsCredentials
from rag_experiment_accelerator.utils.logging import get_logger
from rag_experiment_accelerator.config.environment import Environment

logger = get_logger(__name__)

Expand Down Expand Up @@ -40,6 +40,7 @@ class LanguageEvaluator:

def __init__(
self,
environment: Environment,
query_language="en-us",
default_language="en",
country_hint="",
Expand All @@ -54,8 +55,8 @@ def __init__(
country_hint if country_hint else query_language.split("-")[1]
)
self.confidence_threshold = confidence_threshold
self.creds = AzureSkillsCredentials.from_env_or_keyvault()
self.max_content_length = 50000 # Data limit
self.environment = environment
except Exception as e:
logger.error(str(e))

Expand All @@ -73,11 +74,11 @@ def check_string(self, input_string):

def detect_language(self, text: str):
try:
service_endpoint = self.creds.AZURE_LANGUAGE_SERVICE_ENDPOINT
key = self.creds.AZURE_LANGUAGE_SERVICE_KEY

client = TextAnalyticsClient(
endpoint=service_endpoint, credential=AzureKeyCredential(key)
endpoint=self.environment.azure_language_service_endpoint,
credential=AzureKeyCredential(
self.environment.azure_language_service_key
),
)
response = client.detect_language(documents=[text])

Expand Down

0 comments on commit 7a52865

Please sign in to comment.