Skip to content

Commit

Permalink
type skip_codes, retry_codes as lists; satisfy reasonable linter requ…
Browse files Browse the repository at this point in the history
…ests
  • Loading branch information
leondz committed Nov 12, 2024
1 parent 2b5f109 commit 2983ac9
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions garak/generators/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class RestGenerator(Generator):
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"headers": {},
"method": "post",
"ratelimit_codes": {429},
"skip_codes": set(),
"ratelimit_codes": [429],
"skip_codes": [],
"response_json": False,
"response_json_field": None,
"req_template": "$INPUT",
Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(self, uri=None, config_root=_config):
try:
self.json_expr = jsonpath_ng.parse(self.response_json_field)
except JsonPathParserError as e:
logging.CRITICAL(
logging.critical(
"Couldn't parse response_json_field %s", self.response_json_field
)
raise e
Expand Down Expand Up @@ -198,37 +198,41 @@ def _call_model(

if resp.status_code in self.skip_codes:
logging.debug(
f"REST skip prompt: {resp.status_code} - {resp.reason}, uri: {self.uri}"
"REST skip prompt: %s - %s, uri: %s",
resp.status_code,
resp.reason,
self.uri,
)
return [None]

elif resp.status_code in self.ratelimit_codes:
if resp.status_code in self.ratelimit_codes:
raise RateLimitHit(
f"Rate limited: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "3":
if str(resp.status_code)[0] == "3":
raise NotImplementedError(
f"REST URI redirection: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "4":
if str(resp.status_code)[0] == "4":
raise ConnectionError(
f"REST URI client error: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "5":
if str(resp.status_code)[0] == "5":
error_msg = f"REST URI server error: {resp.status_code} - {resp.reason}, uri: {self.uri}"
if self.retry_5xx:
raise IOError(error_msg)
else:
raise ConnectionError(error_msg)
raise ConnectionError(error_msg)

if not self.response_json:
return [str(resp.text)]

response_object = json.loads(resp.content)

response = [None] * generations_this_call

# if response_json_field starts with a $, treat is as a JSONPath
assert (
self.response_json
Expand Down

0 comments on commit 2983ac9

Please sign in to comment.