Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add retry limits and improve error handling for JSON extraction #4

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 60 additions & 51 deletions ai_scientist/generate_ideas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import backoff
import requests

from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS, MAX_JSON_RETRIES

S2_API_KEY = os.getenv("S2_API_KEY")

Expand Down Expand Up @@ -112,56 +112,65 @@ def generate_ideas(
for _ in range(max_num_generations):
print()
print(f"Generating idea {_ + 1}/{max_num_generations}")
try:
prev_ideas_string = "\n\n".join(idea_str_archive)

msg_history = []
print(f"Iteration 1/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_first_prompt.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
num_reflections=num_reflections,
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
print(json_output)

# Iteratively improve task.
if num_reflections > 1:
for j in range(num_reflections - 1):
print(f"Iteration {j + 2}/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_reflection_prompt.format(
current_round=j + 2, num_reflections=num_reflections
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)

if "I am done" in text:
print(f"Idea generation converged after {j + 2} iterations.")
break

idea_str_archive.append(json.dumps(json_output))
except Exception as e:
print(f"Failed to generate idea: {e}")
continue
retry_count = 0
while retry_count < MAX_JSON_RETRIES:
try:
prev_ideas_string = "\n\n".join(idea_str_archive)

msg_history = []
print(f"Iteration 1/{num_reflections} (Attempt {retry_count + 1}/{MAX_JSON_RETRIES})")
text, msg_history = get_response_from_llm(
idea_first_prompt.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
num_reflections=num_reflections,
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
if json_output is None:
retry_count += 1
continue
print(json_output)

# Iteratively improve task.
if num_reflections > 1:
for j in range(num_reflections - 1):
print(f"Iteration {j + 2}/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_reflection_prompt.format(
current_round=j + 2, num_reflections=num_reflections
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
if json_output is None:
retry_count += 1
continue
print(json_output)

if "I am done" in text:
print(f"Idea generation converged after {j + 2} iterations.")
break

idea_str_archive.append(json.dumps(json_output))
break
except Exception as e:
print(f"Failed to generate idea: {e}")
retry_count += 1
if retry_count >= MAX_JSON_RETRIES:
print(f"Max retries ({MAX_JSON_RETRIES}) reached, skipping idea")
break
continue

## SAVE IDEAS
ideas = []
Expand Down
9 changes: 7 additions & 2 deletions ai_scientist/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import openai

MAX_NUM_TOKENS = 4096
MAX_JSON_RETRIES = 3

AVAILABLE_LLMS = [
"claude-3-5-sonnet-20240620",
Expand Down Expand Up @@ -272,7 +273,10 @@ def extract_json_between_markers(llm_output):
try:
parsed_json = json.loads(json_string)
return parsed_json
except json.JSONDecodeError:
except json.JSONDecodeError as e:
# Provide detailed error message
error_msg = f"JSON parse error: {str(e)}\nContent: {json_string[:100]}..."
print(error_msg)
# Attempt to fix common JSON issues
try:
# Remove invalid control characters
Expand All @@ -282,7 +286,8 @@ def extract_json_between_markers(llm_output):
except json.JSONDecodeError:
continue # Try next match

return None # No valid JSON found
print("No valid JSON found in LLM output")
return None


def create_client(model):
Expand Down
22 changes: 17 additions & 5 deletions ai_scientist/perform_writeup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional, Tuple

from ai_scientist.generate_ideas import search_for_papers
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS, MAX_JSON_RETRIES


# GENERATE LATEX
Expand Down Expand Up @@ -312,8 +312,14 @@ def get_citation_aider_prompt(
return None, True

## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
retry_count = 0
while retry_count < MAX_JSON_RETRIES:
json_output = extract_json_between_markers(text)
if json_output is not None:
break
retry_count += 1
if retry_count >= MAX_JSON_RETRIES:
raise ValueError("Failed to extract JSON after max retries")
query = json_output["Query"]
papers = search_for_papers(query)
except Exception as e:
Expand Down Expand Up @@ -354,8 +360,14 @@ def get_citation_aider_prompt(
print("Do not add any.")
return None, False
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
retry_count = 0
while retry_count < MAX_JSON_RETRIES:
json_output = extract_json_between_markers(text)
if json_output is not None:
break
retry_count += 1
if retry_count >= MAX_JSON_RETRIES:
raise ValueError("Failed to extract JSON after max retries")
desc = json_output["Description"]
selected_papers = json_output["Selected"]
selected_papers = str(selected_papers)
Expand Down