
Dynamic Beams #36

Merged
1 commit merged into main on Oct 18, 2023

Conversation

wongjingping (Collaborator)

  • Use a dynamic number of beams depending on the prompt's token length. We scale the beam count down approximately quadratically, reflecting the quadratic cost of attention, so we no longer need to explicitly free torch memory before the generate call.
  • Update the prompt to match sql-coder.
  • Add tests.
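
The "approximately quadratic" scale-down can be read as a constant-memory-budget rule (a minimal illustration, not the PR's actual code; the 1024-token base length and the floor of 1 beam are assumptions):

```python
def beams_for_budget(tokens: int, max_beams: int = 4, base_tokens: int = 1024) -> int:
    """Scale the beam count down roughly quadratically past base_tokens.

    Attention memory grows roughly with tokens**2 per beam, so holding
    beams * tokens**2 under a fixed budget implies beams ~ 1 / tokens**2.
    """
    if tokens <= base_tokens:
        return max_beams
    return max(1, int(max_beams * (base_tokens / tokens) ** 2))

print(beams_for_budget(512), beams_for_budget(2048))  # -> 4 1
```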

Verified that the latest fix works with the following command:

```shell
$ python main.py \
  -q data/questions_gen.csv \
  -o "results/defog_sqlcoder_npl_short_stratified2_best.csv" \
  -g hf \
  -f prompts/prompt.md \
  -m /home/defog/finetuning/starcoder/defog_sqlcoder_npl_short_stratified2/best
```

```
...
Correct so far: 145/200 (72.50%): 100%|█████████████████████████████████████████████████| 200/200 [08:49<00:00,  2.65s/it]
                exact_match   correct
query_category
date_functions     0.600000  0.600000
group_by           0.857143  0.885714
order_by           0.685714  0.800000
ratio              0.514286  0.600000
table_join         0.771429  0.800000
where              0.600000  0.628571
```
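
A per-category summary like the one above can be produced with a pandas groupby mean over boolean eval results (a sketch with made-up rows; only the column names are taken from the table above):

```python
import pandas as pd

# Toy eval results; 1 = the generated SQL matched / was judged correct.
results = pd.DataFrame({
    "query_category": ["group_by", "group_by", "where", "where"],
    "exact_match": [1, 0, 1, 0],
    "correct": [1, 1, 1, 0],
})

# Mean of 0/1 columns per category gives the accuracy fractions.
summary = results.groupby("query_category")[["exact_match", "correct"]].mean()
print(summary)
```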

@wongjingping wongjingping requested a review from rishsriv October 18, 2023 06:32
@rishsriv rishsriv force-pushed the jp/dynamic_num_beams branch 3 times, most recently from 035297b to f4ff7cd Compare October 18, 2023 07:41
The review comment below was left on this hunk of the diff:

```python
def dynamic_num_beams(prompt: str, tokenizer) -> int:
    tokens = len(tokenizer.encode(prompt))
    print(tokens)
    if tokens <= 1024:
```
rishsriv (Member)
Thanks for implementing this!

Should we still let users set something like "max beams"? The idea here is to a) let them choose a lower number of beams if they want (say, for latency) and b) let them choose a higher number of beams if they want (say, for higher accuracy, and if they have access to high-memory GPUs).

I'm thinking of something like this

```python
def dynamic_num_beams(prompt: str, tokenizer, max_beams=4) -> int:
    tokens = len(tokenizer.encode(prompt))
    print(tokens)
    if tokens <= 1024:
        return max_beams
    elif tokens <= 1536:
        return max_beams // 2
    else:
        return max_beams // 4
```
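
For context, the suggested helper could be wired into generation roughly like this (a sketch: the whitespace tokenizer stub and the generate call mentioned in the comment are illustrative, not the repo's actual code):

```python
def dynamic_num_beams(prompt: str, tokenizer, max_beams: int = 4) -> int:
    """Pick a beam count from the prompt's token length."""
    tokens = len(tokenizer.encode(prompt))
    if tokens <= 1024:
        return max_beams
    elif tokens <= 1536:
        return max_beams // 2
    else:
        return max_beams // 4

class WhitespaceTokenizer:
    """Stand-in for a HF tokenizer: one token per whitespace-separated word."""
    def encode(self, text: str):
        return text.split()

tokenizer = WhitespaceTokenizer()
long_prompt = "word " * 1200          # ~1200 tokens -> half the beams
num_beams = dynamic_num_beams(long_prompt, tokenizer)
# The eval loop would then call model.generate(..., num_beams=num_beams).
print(num_beams)  # -> 2
```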

wongjingping (Collaborator, Author)

Sounds good, will do!

@rishsriv rishsriv force-pushed the jp/dynamic_num_beams branch from f4ff7cd to d6b99c9 Compare October 18, 2023 07:49
@rishsriv (Member)

Generally looks good! Will merge as soon as we can make the max_beams change :D

Commit message (title truncated by GitHub):

…e scale it down approximately quadratically due to the quadratic nature of attention, and allow users to specify max_beams

We no longer need to explicitly free torch memory before the generate call.
Update the prompt to match sql-coder.
Add tests.
Update requirements.txt to support peft.
@rishsriv rishsriv force-pushed the jp/dynamic_num_beams branch from d6b99c9 to eb8b528 Compare October 18, 2023 07:53
@rishsriv rishsriv merged commit 0b1d8f9 into main Oct 18, 2023
2 checks passed
@rishsriv rishsriv deleted the jp/dynamic_num_beams branch October 18, 2023 08:06