Error when using stopping_criteria in .generate if remote=True #137

Open
Butanium opened this issue May 20, 2024 · 5 comments
@Butanium (Contributor)

Remote execution does not support stopping_criteria right now:

from nnsight import LanguageModel
from transformers import StoppingCriteria

class Stopping(StoppingCriteria):
    def __call__(self, input_ids, _scores, **_kwargs):
        return False  # never stop; continue generation

    # Make the instance also behave like a one-element StoppingCriteriaList
    # so it can be passed directly as `stopping_criteria`.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self

nn_model = LanguageModel("meta-llama/Llama-2-70b-hf")

stopping_criteria = Stopping()
with nn_model.generate("hello", remote=True, stopping_criteria=stopping_criteria) as tracer:
    out = nn_model.generator.output.save()
print(nn_model.tokenizer.decode(out[0]))

Error trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[47], line 20
     17 nn_model = LanguageModel("meta-llama/Llama-2-70b-hf")
     19 stopping_criteria = StoppingCriteria()
---> 20 with nn_model.generate("hello", remote=True, stopping_criteria=stopping_criteria) as tracer:
     21     out = nn_model.generator.output.save()
     22 print(nn_model.tokenizer.decode(out[0]))

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/mixins/Generation.py:11, in GenerationMixin.generate(self, *args, **kwargs)
      9 def generate(self, *args, **kwargs) -> Runner:
---> 11     return self.trace(*args, generate=True, **kwargs)

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/NNsightModel.py:196, in NNsight.trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    193         return output.value
    195     # Otherwise open an invoker context with the give args.
--> 196     runner.invoke(*inputs, **invoker_args).__enter__()
    198 # If trace is False, you had to have provided an input.
    199 if not trace:

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/contexts/Invoker.py:69, in Invoker.__enter__(self)
     64     with FakeTensorMode(
     65         allow_non_fake_inputs=True,
     66         shape_env=ShapeEnv(assume_static_by_default=True),
     67     ) as fake_mode:
     68         with FakeCopyMode(fake_mode):
---> 69             self.tracer._model._execute(
     70                 *copy.deepcopy(self.inputs),
     71                 **copy.deepcopy(self.tracer._kwargs),
     72             )
     74     self.scanning = False
     76 else:

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/mixins/Generation.py:19, in GenerationMixin._execute(self, prepared_inputs, generate, *args, **kwargs)
     13 def _execute(
     14     self, prepared_inputs: Any, *args, generate: bool = False, **kwargs
     15 ) -> Any:
     17     if generate:
---> 19         return self._execute_generate(prepared_inputs, *args, **kwargs)
     21     return self._execute_forward(prepared_inputs, *args, **kwargs)

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/LanguageModel.py:293, in LanguageModel._execute_generate(self, prepared_inputs, max_new_tokens, *args, **kwargs)
    287 def _execute_generate(
    288     self, prepared_inputs: Any, *args, max_new_tokens=1, **kwargs
    289 ):
    291     device = next(self._model.parameters()).device
--> 293     output = self._model.generate(
    294         *args,
    295         **prepared_inputs.to(device),
    296         max_new_tokens=max_new_tokens,
    297         **kwargs,
    298     )
    300     self._model.generator(output)
    302     return output

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/transformers/generation/utils.py:1533, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1521 prepared_logits_processor = self._get_logits_processor(
   1522     generation_config=generation_config,
   1523     input_ids_seq_length=input_ids_length,
   (...)
   1529     negative_prompt_attention_mask=negative_prompt_attention_mask,
   1530 )
   1532 # 9. prepare stopping criteria
-> 1533 prepared_stopping_criteria = self._get_stopping_criteria(
   1534     generation_config=generation_config, stopping_criteria=stopping_criteria
   1535 )
   1536 # 10. go into different generation modes
   1537 if generation_mode == GenerationMode.ASSISTED_GENERATION:

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/transformers/generation/utils.py:903, in GenerationMixin._get_stopping_criteria(self, generation_config, stopping_criteria)
    901 if generation_config.eos_token_id is not None:
    902     criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
--> 903 criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
    904 return criteria

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/transformers/generation/utils.py:911, in GenerationMixin._merge_criteria_processor_list(self, default_list, custom_list)
    906 def _merge_criteria_processor_list(
    907     self,
    908     default_list: Union[LogitsProcessorList, StoppingCriteriaList],
    909     custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    910 ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
--> 911     if len(custom_list) == 0:
    912         return default_list
    913     for default in default_list:

TypeError: object of type 'StoppingCriteria' has no len()
@JadenFiotto-Kaufman (Member)

@Butanium

I'm not getting the error you posted, but instead a maximum recursion error.

I need to add a better error message for this, but as things work now, you can't send arbitrary objects to the server.

It has to be one of:
- list
- dict
- tuple
- int, string, None, float, bool
- Tensor
- slice
- whitelisted function
- nnsight Node

You can see this here: https://github.com/ndif-team/nnsight/blob/main/src/nnsight/pydantics/format/types.py
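The constraint above can be illustrated with a small sketch. This is not nnsight's actual implementation (the real whitelist lives in `pydantics/format/types.py` and also handles tensors, whitelisted functions, and nnsight Nodes via dedicated pydantic models); it is just a hypothetical recursive check showing why a custom `StoppingCriteria` instance is rejected while plain containers of primitives pass:

```python
# Illustrative sketch only -- NOT nnsight's real serialization code.
# Tensors, whitelisted functions, and nnsight Nodes are omitted here; the
# real whitelist handles them with dedicated pydantic models.
ALLOWED_LEAF_TYPES = (int, str, float, bool, type(None), slice)

def is_remote_serializable(value):
    """Return True if `value` is built only from whitelisted types,
    checking containers recursively."""
    if isinstance(value, ALLOWED_LEAF_TYPES):
        return True
    if isinstance(value, (list, tuple)):
        return all(is_remote_serializable(v) for v in value)
    if isinstance(value, dict):
        return all(
            is_remote_serializable(k) and is_remote_serializable(v)
            for k, v in value.items()
        )
    # Arbitrary objects -- e.g. a StoppingCriteria subclass -- fail here.
    return False
```

Under this sketch, `is_remote_serializable({"max_new_tokens": 10})` would pass, while `is_remote_serializable(Stopping())` would not, which is why the custom criterion can't be shipped to the server.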

@Butanium (Contributor, Author)

Hi, sorry, I posted the wrong trace; I got a recursion error too 😅
I'll look into it

@Butanium (Contributor, Author)

So if I understand correctly: since every stopping criterion is a different class inheriting from StoppingCriteria, it might not be possible for nnsight to support this argument in remote execution?

@JadenFiotto-Kaufman (Member)

> So if I understand correctly: since every stopping criterion is a different class inheriting from StoppingCriteria, it might not be possible for nnsight to support this argument in remote execution?

Yeah, ndif/nnsight works with a custom serialization format, so it only supports types we explicitly define; otherwise anyone could execute arbitrary code via arbitrary classes. If StoppingCriteria seems useful enough, maybe I could add it.

@Butanium (Contributor, Author)

StoppingCriteria is an abstract class meant to be inherited. If I understand correctly, you'd need to manually whitelist a set of classes inheriting from StoppingCriteria, right?
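One hypothetical shape such a whitelist could take (purely a sketch, not anything nnsight implements): the server keeps an explicit registry of approved criterion classes, and the client sends only a class name plus constructor kwargs, which are all whitelist-serializable primitives. The `MaxNewTokensCriterion` class and `register_criterion`/`deserialize_criterion` helpers below are invented for illustration:

```python
# Hypothetical sketch of a server-side registry of approved criteria.
# Nothing here is real nnsight/ndif API; names are made up for illustration.
APPROVED_CRITERIA = {}

def register_criterion(cls):
    """Add a criterion class to the server-side whitelist."""
    APPROVED_CRITERIA[cls.__name__] = cls
    return cls

@register_criterion
class MaxNewTokensCriterion:
    """Stop once the sequence reaches a fixed length (toy example)."""

    def __init__(self, max_new_tokens):
        self.max_new_tokens = max_new_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids: batch of token-id sequences; stop when long enough.
        return len(input_ids[0]) >= self.max_new_tokens

def deserialize_criterion(payload):
    """Rebuild a criterion from a {name, kwargs} payload, refusing
    anything outside the whitelist."""
    cls = APPROVED_CRITERIA.get(payload["name"])
    if cls is None:
        raise ValueError(f"criterion {payload['name']!r} is not whitelisted")
    return cls(**payload["kwargs"])
```

The client-side payload (`{"name": ..., "kwargs": ...}`) contains only dicts, strings, and ints, so it fits the existing serialization format; the trade-off is exactly what the question raises: every supported criterion class has to be added to the registry by hand.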
