-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate OpenRouter as optional API, keep OpenAI as default #1
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,36 @@ | ||
import argparse | ||
import json | ||
from openai import OpenAI | ||
from openrouter_client import OpenRouterClient | ||
from prompt_templates import instruction_template, knowledge_template, npc_template, math_template | ||
from datasets import load_dataset | ||
from tqdm import tqdm | ||
|
||
system_prompt = '''You are a helpful assistant.''' | ||
client = OpenAI() # set up your config/env/api for calling openai models | ||
|
||
def get_response(user_prompt):
    """Send one chat request to gpt-4o and return the assistant's reply text."""
    chat_messages = [
        {"role": "system", "content": f"{system_prompt}"},
        {"role": "user", "content": f"{user_prompt}"},
    ]
    # Uses the module-level OpenAI client configured at import time.
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=chat_messages,
        temperature=0.7,
    )
    return completion.choices[0].message.content
def get_response(user_prompt, use_openrouter=False, model=None):
    """Send one chat request and return the assistant's reply text.

    Args:
        user_prompt: The user message to send.
        use_openrouter: When True, route the request through OpenRouter
            (via OpenRouterClient) instead of the default OpenAI API.
        model: Optional model name override. Defaults to "openai/gpt-4"
            on OpenRouter and "gpt-4" on OpenAI.

    Returns:
        The content string of the completion's first choice.
    """
    # OpenRouterClient subclasses OpenAI, so both expose the identical
    # chat.completions.create interface -- a ternary keeps the two code
    # paths aligned instead of duplicating the request in an if/else.
    # (The previous OpenRouter branch called client.chat.create, which
    # does not exist on the SDK, and indexed the response like a dict.)
    client = OpenRouterClient() if use_openrouter else OpenAI()
    if model is None:
        model = "openai/gpt-4" if use_openrouter else "gpt-4"
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": f"{system_prompt}"},
            {"role": "user", "content": f"{user_prompt}"},
        ],
        temperature=0.7,
    )
    # The SDK returns response objects (not dicts) in both cases.
    return completion.choices[0].message.content
|
||
def main(args): | ||
# Load the appropriate template | ||
|
@@ -42,26 +55,27 @@ def main(args): | |
for persona in tqdm(persona_dataset['persona']): | ||
persona = persona.strip() | ||
user_prompt = template.format(persona=persona) | ||
gpt4o_out_text = get_response(user_prompt) | ||
o = {"user_prompt": user_prompt, "input persona": persona, "synthesized text": gpt4o_out_text} | ||
response_text = get_response(user_prompt, use_openrouter=args.use_openrouter) | ||
o = {"user_prompt": user_prompt, "input persona": persona, "synthesized text": response_text} | ||
out.write(json.dumps(o, ensure_ascii=False) + '\n') | ||
|
||
print(f"Outputted the results to: {args.output_path}") | ||
|
||
if __name__ == "__main__":
    # CLI entry point: choose a prompt template and synthesize persona text.
    parser = argparse.ArgumentParser(
        description="Synthesize text using OpenAI or OpenRouter and a specified template."
    )
    parser.add_argument(
        '--sample_size',
        type=int,
        default=0,
        help='Number of samples to process from the dataset; Set it to 0 if you want to use the full set of 200k personas.',
    )
    parser.add_argument(
        '--template',
        type=str,
        required=True,
        choices=['instruction', 'knowledge', 'npc', 'math'],
        help=(
            "Prompt templates. Choose from 'instruction', 'knowledge', 'math' or 'npc'. "
            "You can also add more customized templates in prompt_templates.py"
        ),
    )
    parser.add_argument('--output_path', type=str, required=True, help='Path to the output file.')
    parser.add_argument('--use_openrouter', action='store_true', help='Use OpenRouter instead of OpenAI')

    main(parser.parse_args())
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
from openai import OpenAI | ||
|
||
class OpenRouterClient(OpenAI):
    """OpenAI SDK client preconfigured for the OpenRouter API.

    Subclassing OpenAI lets callers use the standard
    ``chat.completions.create`` interface unchanged; only the base URL,
    API key source, and attribution headers differ.
    """

    def __init__(self, api_key=None):
        # Attribution headers go through the constructor's default_headers
        # so the SDK applies them to every request -- no post-__init__
        # mutation of client internals (self.headers is not a supported
        # mutation point on the OpenAI client).
        super().__init__(
            api_key=api_key or os.environ.get("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            default_headers={
                "HTTP-Referer": "https://github.com/tencent-ailab/persona-hub",
                "X-Title": "Persona Hub",
            },
        )

    def chat_create(self, model, messages, temperature=0.7, **kwargs):
        """Convenience wrapper kept for backward compatibility.

        Prefer calling ``self.chat.completions.create`` directly.
        """
        # The invalid ``headers=`` kwarg is dropped: chat.completions.create
        # does not accept it, and default_headers already covers attribution.
        return self.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            **kwargs,
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
datasets | ||
transformers | ||
openai | ||
tqdm | ||
tqdm |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Wrong code here -- this should be
chat_create()
Better yet, the internal API should just use
chat.completions.create
instead of doing chat_create(). The override doesn't make sense to me