Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary stuff #25

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,9 @@
DEFAULT_UNK_TOKEN = "</s>"
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
Expand All @@ -58,9 +54,9 @@ class DataArguments:
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
optim: str = field(default="adamw_bnb_8bit")
model_max_length: int = field(
default=512,
default=2048
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)

Expand Down
117 changes: 0 additions & 117 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,123 +13,6 @@
from openai import openai_object
import copy

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
openai.organization = openai_org
logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
max_tokens: int = 1800
temperature: float = 0.2
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Optional[Sequence[str]] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
suffix: Optional[str] = None
logprobs: Optional[int] = None
echo: bool = False


def openai_completion(
prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
decoding_args: OpenAIDecodingArguments,
model_name="text-davinci-003",
sleep_time=2,
batch_size=1,
max_instances=sys.maxsize,
max_batches=sys.maxsize,
return_text=False,
**decoding_kwargs,
) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]:
"""Decode with OpenAI API.

Args:
prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
it can also be a dictionary (or list thereof) as explained here:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
decoding_args: Decoding arguments.
model_name: Model name. Can be either in the format of "org/model" or just "model".
sleep_time: Time to sleep once the rate-limit is hit.
batch_size: Number of prompts to send in a single request. Only for non chat model.
max_instances: Maximum number of prompts to decode.
max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
return_text: If True, return text instead of full completion object (which contains things like logprob).
decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

Returns:
A completion or a list of completions.
Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
- a string (if return_text is True)
- an openai_object.OpenAIObject object (if return_text is False)
- a list of objects of the above types (if decoding_args.n > 1)
"""
is_single_prompt = isinstance(prompts, (str, dict))
if is_single_prompt:
prompts = [prompts]

if max_batches < sys.maxsize:
logging.warning(
"`max_batches` will be deprecated in the future, please use `max_instances` instead."
"Setting `max_instances` to `max_batches * batch_size` for now."
)
max_instances = max_batches * batch_size

prompts = prompts[:max_instances]
num_prompts = len(prompts)
prompt_batches = [
prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
for batch_id in range(int(math.ceil(num_prompts / batch_size)))
]

completions = []
for batch_id, prompt_batch in tqdm.tqdm(
enumerate(prompt_batches),
desc="prompt_batches",
total=len(prompt_batches),
):
batch_decoding_args = copy.deepcopy(decoding_args) # cloning the decoding_args

while True:
try:
shared_kwargs = dict(
model=model_name,
**batch_decoding_args.__dict__,
**decoding_kwargs,
)
completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
choices = completion_batch.choices

for choice in choices:
choice["total_tokens"] = completion_batch.usage.total_tokens
completions.extend(choices)
break
except openai.error.OpenAIError as e:
logging.warning(f"OpenAIError: {e}.")
if "Please reduce your prompt" in str(e):
batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
else:
logging.warning("Hit request rate limit; retrying...")
time.sleep(sleep_time) # Annoying rate limit on requests.

if return_text:
completions = [completion.text for completion in completions]
if decoding_args.n > 1:
# make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
if is_single_prompt:
# Return non-tuple if only 1 input and 1 generation.
(completions,) = completions
return completions


def _make_w_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f_dirname = os.path.dirname(f)
Expand Down