diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py
index d5f4f97f..48ba1cd7 100644
--- a/ai_scientist/generate_ideas.py
+++ b/ai_scientist/generate_ideas.py
@@ -7,7 +7,7 @@
 import backoff
 import requests
 
-from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
+from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS, MAX_JSON_RETRIES
 
 S2_API_KEY = os.getenv("S2_API_KEY")
 
@@ -112,56 +112,65 @@ def generate_ideas(
     for _ in range(max_num_generations):
         print()
         print(f"Generating idea {_ + 1}/{max_num_generations}")
-        try:
-            prev_ideas_string = "\n\n".join(idea_str_archive)
-
-            msg_history = []
-            print(f"Iteration 1/{num_reflections}")
-            text, msg_history = get_response_from_llm(
-                idea_first_prompt.format(
-                    task_description=prompt["task_description"],
-                    code=code,
-                    prev_ideas_string=prev_ideas_string,
-                    num_reflections=num_reflections,
-                ),
-                client=client,
-                model=model,
-                system_message=idea_system_prompt,
-                msg_history=msg_history,
-            )
-            ## PARSE OUTPUT
-            json_output = extract_json_between_markers(text)
-            assert json_output is not None, "Failed to extract JSON from LLM output"
-            print(json_output)
-
-            # Iteratively improve task.
-            if num_reflections > 1:
-                for j in range(num_reflections - 1):
-                    print(f"Iteration {j + 2}/{num_reflections}")
-                    text, msg_history = get_response_from_llm(
-                        idea_reflection_prompt.format(
-                            current_round=j + 2, num_reflections=num_reflections
-                        ),
-                        client=client,
-                        model=model,
-                        system_message=idea_system_prompt,
-                        msg_history=msg_history,
-                    )
-                    ## PARSE OUTPUT
-                    json_output = extract_json_between_markers(text)
-                    assert (
-                        json_output is not None
-                    ), "Failed to extract JSON from LLM output"
-                    print(json_output)
-
-                    if "I am done" in text:
-                        print(f"Idea generation converged after {j + 2} iterations.")
-                        break
-
-            idea_str_archive.append(json.dumps(json_output))
-        except Exception as e:
-            print(f"Failed to generate idea: {e}")
-            continue
+        retry_count = 0
+        while retry_count < MAX_JSON_RETRIES:
+            try:
+                prev_ideas_string = "\n\n".join(idea_str_archive)
+
+                msg_history = []
+                print(f"Iteration 1/{num_reflections} (Attempt {retry_count + 1}/{MAX_JSON_RETRIES})")
+                text, msg_history = get_response_from_llm(
+                    idea_first_prompt.format(
+                        task_description=prompt["task_description"],
+                        code=code,
+                        prev_ideas_string=prev_ideas_string,
+                        num_reflections=num_reflections,
+                    ),
+                    client=client,
+                    model=model,
+                    system_message=idea_system_prompt,
+                    msg_history=msg_history,
+                )
+                ## PARSE OUTPUT
+                json_output = extract_json_between_markers(text)
+                if json_output is None:
+                    retry_count += 1
+                    continue
+                print(json_output)
+
+                # Iteratively improve task.
+                if num_reflections > 1:
+                    for j in range(num_reflections - 1):
+                        print(f"Iteration {j + 2}/{num_reflections}")
+                        text, msg_history = get_response_from_llm(
+                            idea_reflection_prompt.format(
+                                current_round=j + 2, num_reflections=num_reflections
+                            ),
+                            client=client,
+                            model=model,
+                            system_message=idea_system_prompt,
+                            msg_history=msg_history,
+                        )
+                        ## PARSE OUTPUT
+                        json_output = extract_json_between_markers(text)
+                        if json_output is None:
+                            retry_count += 1
+                            continue
+                        print(json_output)
+
+                        if "I am done" in text:
+                            print(f"Idea generation converged after {j + 2} iterations.")
+                            break
+
+                idea_str_archive.append(json.dumps(json_output))
+                break
+            except Exception as e:
+                print(f"Failed to generate idea: {e}")
+                retry_count += 1
+                if retry_count >= MAX_JSON_RETRIES:
+                    print(f"Max retries ({MAX_JSON_RETRIES}) reached, skipping idea")
+                    break
+                continue
 
     ## SAVE IDEAS
     ideas = []
diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py
index 7811fb92..21871f3f 100644
--- a/ai_scientist/llm.py
+++ b/ai_scientist/llm.py
@@ -7,6 +7,7 @@
 import openai
 
 MAX_NUM_TOKENS = 4096
+MAX_JSON_RETRIES = 3
 
 AVAILABLE_LLMS = [
     "claude-3-5-sonnet-20240620",
@@ -272,7 +273,10 @@
         try:
             parsed_json = json.loads(json_string)
             return parsed_json
-        except json.JSONDecodeError:
+        except json.JSONDecodeError as e:
+            # Provide detailed error message
+            error_msg = f"JSON parse error: {str(e)}\nContent: {json_string[:100]}..."
+            print(error_msg)
             # Attempt to fix common JSON issues
             try:
                 # Remove invalid control characters
@@ -282,7 +286,8 @@ def extract_json_between_markers(llm_output):
             except json.JSONDecodeError:
                 continue  # Try next match
 
-    return None  # No valid JSON found
+    print("No valid JSON found in LLM output")
+    return None
 
 
 def create_client(model):
diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py
index 7dc9eebe..4dc5426d 100644
--- a/ai_scientist/perform_writeup.py
+++ b/ai_scientist/perform_writeup.py
@@ -8,7 +8,7 @@
 from typing import Optional, Tuple
 
 from ai_scientist.generate_ideas import search_for_papers
-from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
+from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS, MAX_JSON_RETRIES
 
 
 # GENERATE LATEX
@@ -312,8 +312,14 @@ def get_citation_aider_prompt(
             return None, True
 
         ## PARSE OUTPUT
-        json_output = extract_json_between_markers(text)
-        assert json_output is not None, "Failed to extract JSON from LLM output"
+        retry_count = 0
+        while retry_count < MAX_JSON_RETRIES:
+            json_output = extract_json_between_markers(text)
+            if json_output is not None:
+                break
+            retry_count += 1
+        if retry_count >= MAX_JSON_RETRIES:
+            raise ValueError("Failed to extract JSON after max retries")
         query = json_output["Query"]
         papers = search_for_papers(query)
     except Exception as e:
@@ -354,8 +360,14 @@ def get_citation_aider_prompt(
             print("Do not add any.")
             return None, False
         ## PARSE OUTPUT
-        json_output = extract_json_between_markers(text)
-        assert json_output is not None, "Failed to extract JSON from LLM output"
+        retry_count = 0
+        while retry_count < MAX_JSON_RETRIES:
+            json_output = extract_json_between_markers(text)
+            if json_output is not None:
+                break
+            retry_count += 1
+        if retry_count >= MAX_JSON_RETRIES:
+            raise ValueError("Failed to extract JSON after max retries")
         desc = json_output["Description"]
         selected_papers = json_output["Selected"]
         selected_papers = str(selected_papers)
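For reference, a minimal standalone sketch of the retry pattern this patch introduces: JSON-extraction failures count against MAX_JSON_RETRIES and the attempt is repeated instead of asserting. It is not part of the patch; the extract_json_between_markers helper below is a stripped-down re-implementation just sufficient to run on its own, and fake_llm_response / generate_one_idea are hypothetical stand-ins for the repo's real LLM call.

```python
import json
import re

MAX_JSON_RETRIES = 3  # mirrors the constant added to ai_scientist/llm.py


def extract_json_between_markers(llm_output: str):
    """Minimal re-implementation: return the first parseable ```json ... ``` block, else None."""
    for match in re.findall(r"```json(.*?)```", llm_output, re.DOTALL):
        try:
            return json.loads(match.strip())
        except json.JSONDecodeError:
            continue  # try the next fenced block
    return None


def fake_llm_response(attempt: int) -> str:
    # Hypothetical LLM: malformed JSON on the first attempt, valid JSON afterwards.
    if attempt == 0:
        return '```json\n{"Name": "broken",}\n```'
    return '```json\n{"Name": "grid_search", "Title": "A toy idea"}\n```'


def generate_one_idea():
    retry_count = 0
    while retry_count < MAX_JSON_RETRIES:
        text = fake_llm_response(retry_count)  # the real code queries the LLM again here
        json_output = extract_json_between_markers(text)
        if json_output is not None:
            return json_output
        retry_count += 1  # extraction failed: count the attempt and retry
    raise ValueError("Failed to extract JSON after max retries")


if __name__ == "__main__":
    print(generate_one_idea())  # -> {'Name': 'grid_search', 'Title': 'A toy idea'}
```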