Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add config for optional parameters in a chat message #2260

Open
wants to merge 28 commits into
base: main
Choose a base branch
from

Conversation

NJordan72
Copy link
Contributor

Description

Currently only role/content properties are passed into the chat template (with the exception of some hardcoded optional params that are model specific). This PR creates a configuration property to add additional optional parameters.

Motivation and Context

  • Chat templates are not restricted to using role/content properties and we should allow users, especially those creating their own custom chat templates to pass in anything via the message.
  • OPEN QUESTION: Why restrict the properties passed into the chat template in the first place? Couldn't we just pass the entire message in and if the template doesn't use a particular property they just become no-ops?

How has this been tested?

  • Added unit test

Copy link
Collaborator

@winglian winglian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few minor nits, but overall looks like a good improvement to chat templates.

@winglian winglian requested a review from NanoCode012 January 14, 2025 16:30
@NJordan72 NJordan72 marked this pull request as ready for review January 14, 2025 19:58
@NJordan72
Copy link
Contributor Author

A few minor nits, but overall looks like a good improvement to chat templates.

fixed.

Copy link
Collaborator

@NanoCode012 NanoCode012 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR.

Why restrict the properties passed into the chat template in the first place?

The initial code was written when it was mainly role/content. I expanded it recently for tool_calling datasets, so this is a welcome change.

Couldn't we just pass the entire message in and if the template doesn't use a particular property they just become no-ops?

This could be an alternative, but I think the current solution works fine?

docs/dataset-formats/conversation.qmd Outdated Show resolved Hide resolved
tests/prompt_strategies/test_chat_templates_advanced.py Outdated Show resolved Hide resolved
docs/config.qmd Outdated
Comment on lines 140 to 141
# Fields that will be passed to the chat template if present in the message. (Optional[List[str]], default: None)
optional_message_fields:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we say we include tool_calls, name... by default based on your implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to go either way. Right now we include them as hardcoded, but I'd be happy to refactor it to just make them the default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the downside of making it the default is that if someone wants to add an additional field, they -- at that point -- have to remember to add back the defaults as well.

It is why my slight preference is just to pass the entire message into the chat template and let the template itself decide what it cares about.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other option... Would be to parse out all of the arguments from the jinja template and pass in any/all of those. Jinja provides a pretty easy way to extract the variables from the AST (per documentation, I haven't tried it yet).

We could even warn if the template is expecting something it isn't being passed which might be a nice quality of life addition.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, should this be a List[str] (current)

or a List[Dict[str, str]] which would allow us to map alternate field names to the field in the jinja template?

The latter more closely matches the meaning of the other args like message_field_content, message_field_role

@NJordan72
Copy link
Contributor Author

I just updated this to be a little more thoughtful / robust...

  • I removed the optional_message_fields option and replaced it with a message_property_mappings option.
  • message_property_mappings is basically a more generalized version of the message_field_role and message_field_content configuration options. Those continue to work and are just sucked into the message_property_mappings if provided. The idea is you can map any input property to any message property using this configuration.
  • I then had to update chat_template.py to use the new mappings configuration property to transform the message.
  • I added a JinjaTemplateAnalyzer that introspects the jinja template to figure out which properties a template is expecting (this is generalized, but it can also look at the properties specifically for the messages, which is what we care about here.
  • We then remove any properties from the transformed message that the template does not need/expect.
  • I've taken out special handling of tool_call_id, etc. as it is no longer needed an is handled by the introspection of the template itself
  • I've added test coverage for all of the new stuff and all of the old tests continue to pass

Comment on lines 77 to 82
message_field_role="role",
message_field_content="content",
message_property_mappings={
"from": "role",
"value": "content",
},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to invert the property mappings so that the key is role and content. Because we always know we need to expect those keys, and the mapping is the field in the dataset that we should expect to find them.

Just for the sake of completeness and being able to verify the previous tests in CI and ensure backwards compatibility, It might be good to keep these tests with message_field_role/message_field_content and add a new test suite for message_property_mappings, even if only for a single model architecture.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on it!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These specific tests are hard to keep as they were, because the signature of the function is no longer expecting message_field_role/message_field_content

Instead we are shuttling message_field_role/message_field_content into message_property_mappings as part of the config validation.

I have added a test to the model to make sure they are indeed making it into the mappings.

@winglian
Copy link
Collaborator

Here's the stack trace locally of the broken test in CI currently

  File "/Users/wing/Projects/ml/axolotl/src/axolotl/prompt_strategies/chat_template.py", line 275, in tokenize_prompt
    input_ids = self.prompter.build_prompt(turns)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wing/Projects/ml/axolotl/src/axolotl/prompt_strategies/chat_template.py", line 84, in build_prompt
    return self.tokenizer.apply_chat_template(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wing/.pyenv/versions/3.11.10/envs/axolotl-3.11/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 1683, in apply_chat_template
    rendered_chat = compiled_template.render(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wing/.pyenv/versions/3.11.10/envs/axolotl-3.11/lib/python3.11/site-packages/jinja2/environment.py", line 1304, in render
    self.environment.handle_exception()
  File "/Users/wing/.pyenv/versions/3.11.10/envs/axolotl-3.11/lib/python3.11/site-packages/jinja2/environment.py", line 939, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 1, in top-level template code
jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'role'

and the dataset configuration that is working on main but is broken in the tests

chat_template: llama3
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_field_role: from
    message_field_content: value

@NJordan72
Copy link
Contributor Author

Still looking into the e2e failure (if this push didn't fix it).

@NJordan72
Copy link
Contributor Author

I tracked down the problem with the failing e2e test. I'm sure there are some further nits, feedback which I'm happy to incorporate. Some other thoughts below that explain some of the changes and then maybe a broader topic that we can take off of this thread if there is any interest.

--

I will say that the one challenge I had is there seems to be a lack of consistency as to if the config that is getting passed around to all of the various classes/functions is supposed to be the Dict representation of the Pydantic model or the model itself.

Many of the tests are obviously just constructing a Dict that represents the config which obviously makes sense as to why they would want to pass that around, but it feels like it would be nicer if the Pydantic model was the lingua franca across the entire project.

It feels like the config is the "special sauce" in some ways for the project. Instead of having to write 1000 lines of boilerplate code to get something running, just craft a fairly simple (I'm probably being a little generous here) config file and you're on your way. I know a lot of the complexity of the project lies in how to deal with that config file, but it feels like the config itself is something of a second class citizen insofar that the Dict representation of it seems to be used in a bunch of places and there are bits of the config that seem under documented and somewhat unclear.

Maybe I'm not seeing the forest from the trees here (and Python is certainly not my bailiwick) , but I'd be interested in helping tighten things up around this if there is any interest in it.

@@ -311,6 +311,9 @@ class KTODataset(BaseModel):
revision: Optional[str] = None


DatasetConfig = Union[SFTDataset, DPODataset, KTODataset]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you intend to incorporate this into the line far below?

datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep. thanks.

@winglian
Copy link
Collaborator

@NJordan72 Thanks for your thoughts and insights. We agree that the docs are definitely lacking and we are on a mission this first few months of the year to improve in that area. The evolution of the project started with a basic DefaultDict with naive validation functions written against it. Only last year did we convert the validation to use pydantic, hence the sort of mismatch of using pydantic up front and then the conversion to a dict after that. We're definitely open to making the pydantic model the common language through the project, but in places there will still be a mismatch when we're using the lower level dataclasses that we're extending from HF transformers. I'd love your thoughts on how you might approach this.

Comment on lines 79 to 80
validate_config(cfg)
normalize_config(cfg)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a good call to add the validation in this test as it surfaces a new bug with parsing the chat_dataset. It seems we don't do this consistently across the rest of the tests, and seems like we should add this to more of them to make sure the change for the chat_dataset class doesn't affect other logic.

Comment on lines 47 to 49
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
chat_template
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, messages_array_name refers to the key of the List[dict]. Should we change the signature so that field_messages is passed to the Prompter as well? To allow passing messages_array_name=field_messages?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Just updated that.

Comment on lines 214 to 228
if "message_field_role" in data:
if (
"role" in data["message_property_mappings"]
and data["message_property_mappings"]["role"]
!= data["message_field_role"]
):
raise ValueError(
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
)
data["message_property_mappings"]["role"] = (
data["message_field_role"] or "role"
)
elif "role" not in data["message_property_mappings"]:
data["message_property_mappings"]["role"] = "role"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After setting the property, should we drop the old fields? (Same for content below)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'll leave the property there for now. My intention is to come back and implement #2271 in a way that gives us much better type safety and validation logic, at which point it should be easier to to

Comment on lines 54 to 63
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_property_mappings = ds_cfg.get("message_property_mappings")
message_field_training = ds_cfg.get("message_field_training")

builder_kwargs = {}
if field_messages:
builder_kwargs["conversations_field"] = field_messages
if message_field_role:
builder_kwargs["message_field_role"] = message_field_role
if message_field_content:
builder_kwargs["message_field_content"] = message_field_content
if message_property_mappings and "role" in message_property_mappings:
builder_kwargs["message_field_role"] = message_property_mappings["role"]
if message_property_mappings and "content" in message_property_mappings:
builder_kwargs["message_field_content"] = message_property_mappings["content"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes should probably reverted and go into chat_template.py. We can address changing the messages type in a later PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted in the most recent revision

@NJordan72
Copy link
Contributor Author

fyi, will crank the changes out today and tomorrow.

@NJordan72
Copy link
Contributor Author

@winglian @NanoCode012 looks like the remaining failure is unrelated to my changes here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants