Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
sumadhva30 committed Jun 6, 2024
1 parent 074f81b commit bec91e0
Show file tree
Hide file tree
Showing 20 changed files with 56 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""__init__.py."""
"""__init__.py."""
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# flake8: noqa: F401

"""__init__.py."""

from .batch_api_client import BatchApiClient
Expand All @@ -12,4 +14,4 @@
Message,
SystemMessage,
UserMessage
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from batch_api.data_validation_result import DataValidationResult

BATCH_DATA_VALIDATION_RESULT_SUBMISSION_URL = "https://managed-batch-inference" # TODO: Update the URL
BATCH_DATA_VALIDATION_RESULT_SUBMISSION_URL = "https://managed-batch-inference" # TODO: Update the URL

logger = getLogger(__name__)

Expand Down Expand Up @@ -37,4 +37,4 @@ def submit_validation_result(self, data_validation_result: DataValidationResult)
json=data_validation_result
)

logger.info(f"MBI service response: Status Code: '{response.status_code}', Response: '{response.text}'")
logger.info(f"MBI service response: Status Code: '{response.status_code}', Response: '{response.text}'")
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ class BatchReference:
aoai_account_name: str
batch_id: str
resource_group_name: str
subscription_id: str
subscription_id: str
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ class DataValidationResult:
errors: List[BatchValidationError]

def __init__(self):
self.errors = []
self.errors = []
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# flake8: noqa: F401

"""__init__.py."""

from .request_body import RequestBody
Expand All @@ -9,4 +11,4 @@
Message,
SystemMessage,
UserMessage
)
)
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# flake8: noqa: F401

"""__init__.py."""

from .chat_completion_request_body import ChatCompletionRequestBody
from .message import (
Message,
SystemMessage,
UserMessage
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
class ChatCompletionRequestBody(RequestBody):
"""Request body for chat completion"""

messages: List[Union[SystemMessage, UserMessage]]
messages: List[Union[SystemMessage, UserMessage]]
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ class SystemMessage(Message):
class UserMessage(Message):
"""UserMessage"""

content: Any
content: Any
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class RequestBody:
"""Request Body"""

model: str
user: Optional[str]
user: Optional[str]
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# flake8: noqa: F401

"""__init__.py."""

from .row_validation_context import RowValidationContext
Expand All @@ -10,4 +12,4 @@
from .base_validator import BaseValidator
from .json_validator import JsonValidator
from .schema_validator import SchemaValidator
from .common_property_validator import CommonPropertyValidator
from .common_property_validator import CommonPropertyValidator
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ class BaseValidator(ABC):
@abstractmethod
def validate_row(self, row_context: RowValidationContext) -> RowValidationResult:
"""Validate the row"""
pass
pass
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ class InputRow:
custom_id: str
method: str
url: str
body: Union[RequestBody, ChatCompletionRequestBody]
body: Union[RequestBody, ChatCompletionRequestBody]
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def validate_row(self, row_context: RowValidationContext) -> RowValidationResult
line=row_context.line_number
)

return result
return result
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ class RowValidationContext:
def __init__(self, raw_input_row: str, line_number: Optional[int] = None):
self.raw_input_row = raw_input_row
self.parsed_input_row = None
self.line_number = line_number
self.line_number = line_number
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ class RowValidationResult:
is_success: bool = error is None

def __init__(self) -> None:
self.error = None
self.error = None
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def validate_row(self, row_context: RowValidationContext) -> RowValidationResult
input_row_dict = json.loads(row_context.raw_input_row)
row_context.parsed_input_row = self.get_valid_input_row(input_row_dict)

except:
except Exception:
result.error = BatchValidationError(
code=AoaiBatchValidationErrorCode.INVALID_REQUEST,
message=BatchValidationErrorMessage.INVALID_REQUEST,
Expand Down Expand Up @@ -86,4 +86,4 @@ def get_valid_message(self, data: dict) -> Union[SystemMessage, UserMessage]:
)

else:
raise ValueError("Invalid message role")
raise ValueError("Invalid message role")
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""__init__.py."""
"""__init__.py."""
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ class BatchValidationError:
line: Optional[int] = None
param: Optional[str] = None

def __init__(self, code: AoaiBatchValidationErrorCode, message: BatchValidationErrorMessage, line: Optional[int] = None, param: Optional[str] = None):
def __init__(
self,
code: AoaiBatchValidationErrorCode,
message: BatchValidationErrorMessage,
line: Optional[int] = None,
param: Optional[str] = None
):
self.code = code
self.message = message
self.line = line
self.param = param
self.param = param
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,18 @@
import logging
from typing import List

from batch_api import BatchApiClient, DataValidationResult
from row_validators import *
from batch_api import (
# BatchApiClient, # TODO: Uncomment this line to import BatchApiClient
DataValidationResult
)
from row_validators import (
BaseValidator,
JsonValidator,
SchemaValidator,
CommonPropertyValidator,
RowValidationContext,
RowValidationResult
)
from utils.exceptions import (
AoaiBatchValidationErrorCode,
BatchValidationErrorMessage,
Expand All @@ -26,7 +36,7 @@ def parse_args():
"--input_data_file",
type=str,
help="Path to the input file",
required=True
required=True
)

return parser.parse_args()
Expand Down Expand Up @@ -54,7 +64,7 @@ def main():

row_validators = get_row_validators()

batch_api_client = BatchApiClient()
# batch_api_client = BatchApiClient() # TODO: Uncomment this line to create an instance of BatchApiClient
data_validation_result = DataValidationResult()

if not input_data:
Expand All @@ -76,12 +86,13 @@ def main():
line_number=i
)

for validator in row_validators:
for validator in row_validators:
row_validation_result: RowValidationResult = validator.validate_row(row_validation_context)

if not row_validation_result.is_success:
logger.error(
f"Validation failed for input row '{i}'. Error code: '{row_validation_result.error.code}'. " +
f"Validation failed for input row '{i}'. " +
f"Error code: '{row_validation_result.error.code}'. " +
f"Error message: '{row_validation_result.error.message}'"
)

Expand All @@ -106,10 +117,11 @@ def main():
finally:
logger.info("Submitting validation result to Batch API.")

# batch_api_client.submit_validation_result(data_validation_result) # TODO: Uncomment this line to submit the validation result to Batch API
# TODO: Uncomment this line to submit the validation result to Batch API
# batch_api_client.submit_validation_result(data_validation_result)

logger.info("Validation result submitted to Batch API. Exiting.")


if __name__ == "__main__":
main()
main()

0 comments on commit bec91e0

Please sign in to comment.