Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add config for optional parameters in a chat message #2260

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d4ac2c8
feat: add config for optional parameters in a chat message
NJordan72 Jan 14, 2025
ab270e9
chore: cleanup
NJordan72 Jan 14, 2025
17a824a
chore: fix nits and add light docs
NJordan72 Jan 14, 2025
29d23f6
docs: update docs/dataset-formats/conversation.qmd
NJordan72 Jan 15, 2025
a36c4dc
Merge branch 'main' into feat/optional-chat-template-fields
NJordan72 Jan 15, 2025
c8ade2b
feat: configurable message mappings, jinja template analyzer
NJordan72 Jan 16, 2025
b1e8fab
chore: handle bradley terry
NJordan72 Jan 16, 2025
36aa1c0
docs: update docs
NJordan72 Jan 16, 2025
3f96f37
refactor: change order of mappings, improve message transform
NJordan72 Jan 16, 2025
8ea50e3
refactor: make chat aware of property mappings
NJordan72 Jan 16, 2025
0af59f7
chore: remove .python-version
NJordan72 Jan 16, 2025
2197c49
chore: revert change
NJordan72 Jan 16, 2025
1d0cba2
chore: add dataset validation to tests where appropriate
NJordan72 Jan 17, 2025
d0554d0
chore: add dataset validation to tests where appropriate
NJordan72 Jan 17, 2025
31b47e7
chore: clean up handling of ds_cfg
NJordan72 Jan 17, 2025
b2f93ea
chore: recursively serialize config
NJordan72 Jan 17, 2025
ff599a8
make sure to use the return value from validate_config
winglian Jan 17, 2025
126b040
DefaultDict pickle/unpickle fix
winglian Jan 17, 2025
c265115
fix super call for override
winglian Jan 17, 2025
0237fd4
refactor: message fields
NJordan72 Jan 21, 2025
33ef3ae
chore: empty commit
NJordan72 Jan 14, 2025
10ac0fe
tests: validate config before using
NJordan72 Jan 21, 2025
c607a7e
chore: add config validation to all e2e tests
NJordan72 Jan 21, 2025
7683523
chore: add unneeded logging
NJordan72 Jan 21, 2025
d0ccd95
chore: add missed config validation
NJordan72 Jan 21, 2025
ef7c1e1
chore: pass field_messages to prompter
NJordan72 Jan 21, 2025
867f1bc
Merge branch 'main' into feat/optional-chat-template-fields
NJordan72 Jan 21, 2025
1c40a54
test: fix borked test
NJordan72 Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ datasets:
message_field_role: role
# Key for content in each message (default: "content")
message_field_content: content
# Mapping of properties from the input dataset to the chat template. (default: None)
message_property_mappings:

# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/prompt_strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor

return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
Expand Down
3 changes: 1 addition & 2 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_property_mappings": ds_cfg.get("message_property_mappings", {}),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
Expand Down
108 changes: 73 additions & 35 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

from transformers import ProcessorMixin

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
Expand All @@ -25,8 +26,7 @@ def __init__(
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "role",
message_field_content: str = "content",
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
Expand All @@ -44,8 +44,10 @@ def __init__(
"tool": "tool",
}

self.message_field_role = message_field_role
self.message_field_content = message_field_content
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
chat_template
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, messages_array_name refers to the key of the List[dict]. Should we change the signature so that field_messages is passed to the Prompter as well? To allow passing messages_array_name=field_messages?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Just updated that.

self.message_property_mappings = message_property_mappings
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
Expand All @@ -54,6 +56,10 @@ def __init__(
self.max_length = max_length
self.drop_system_message = drop_system_message

@property
def chat_template_msg_variables(self) -> Set[str]:
    """Read-only view of the per-message variable names that the configured
    chat template references (computed once in ``__init__`` via
    ``get_chat_template_msg_variables``)."""
    return self._chat_template_msg_variables

def build_prompt(self, conversation, add_generation_prompt=False, images=None):
if self.processor:
text = self.processor.apply_chat_template(
Expand Down Expand Up @@ -183,6 +189,12 @@ def adjust_train_details(

return adjusted_details

def get_chat_template_msg_variables(
    self, chat_template: str, messages_array_name: str = "messages"
) -> Set[str]:
    """Return the set of attribute/key names the Jinja ``chat_template``
    accesses on each element of the ``messages_array_name`` array
    (e.g. ``content``, ``tool_calls``).

    :param chat_template: raw Jinja chat-template string to analyze
    :param messages_array_name: name of the list variable iterated by the
        template; defaults to ``"messages"``
    """
    template_analyzer = JinjaTemplateAnalyzer(chat_template)
    return template_analyzer.get_message_vars(messages_array_name)


class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Expand Down Expand Up @@ -212,6 +224,10 @@ def __init__(
self.train_on_eos = train_on_eos
self.images = "images"

LOG.info(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)

@property
def messages(self):
return self._messages
Expand Down Expand Up @@ -424,61 +440,83 @@ def find_turn(self, turns: list[dict], turn_idx: int):

def get_conversation_thread(self, prompt):
turns = []
optional_keys = [
"tool_calls", # tool that 'assistant' calls
"name", # name of tool given by 'tool'
"tool_call_id", # mistral/mixtral requires this
]
for message in prompt[self.messages]:
transformed_message = self.transform_message(message)
LOG.warning(f"Message: {message}")
LOG.warning(f"Transformed message: {transformed_message}")

turn = {
"role": self.prompter.roles[message[self.prompter.message_field_role]],
**transformed_message,
"training": message.get(self.prompter.message_field_training),
"training_detail": message.get(
self.prompter.message_field_training_detail
),
}

# do not add content if None as it may conflict with some templates due to tools
content = message.get(self.prompter.message_field_content, None)
if content is not None:
turn["content"] = content

for key in optional_keys:
value = message.get(key, None)
if value is not None:
turn[key] = value

turns.append(turn)

if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

return turns

def transform_message(self, message):
    """Transform a raw dataset message dict into the shape the chat template expects.

    Steps, in order:
      1. apply ``prompter.message_property_mappings`` (output key -> input key),
         skipping entries whose source value is ``None``;
      2. normalize the mapped ``role`` through ``prompter.roles`` (falling back
         to the original role when unmapped);
      3. pass through any *unmapped* keys of the original message, but only if
         the chat template actually references them
         (``prompter.chat_template_msg_variables``) and their value is not ``None``.

    :param message: one message dict from the source dataset
    :return: new dict containing only mapped and template-referenced properties
    """
    # Build the initial transformed message from the mappings
    transformed_message = {
        key: message[value]
        for key, value in self.prompter.message_property_mappings.items()
        if message.get(value) is not None
    }

    # Map the role if necessary
    if "role" in transformed_message:
        transformed_message["role"] = self.prompter.roles.get(
            transformed_message["role"], transformed_message["role"]
        )

    # Determine which keys in the original message were not mapped
    mapped_values = set(self.prompter.message_property_mappings.values())
    remaining_keys = set(message) - mapped_values

    # Keep only the properties defined in the chat template
    # and not already mapped
    for key in self.prompter.chat_template_msg_variables:
        if key in remaining_keys:
            val = message.get(key)
            if val is not None:
                transformed_message[key] = val

    return transformed_message

def get_images(self, prompt):
return prompt.get(self.images, None)


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
def load(
tokenizer,
cfg,
ds_cfg: Optional[Dict[str, Any]] = None,
processor=None,
):
dataset_config = ds_cfg if ds_cfg else {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_property_mappings": dataset_config.get(
"message_property_mappings", {}
),
"message_field_training": dataset_config.get("message_field_training", None),
"message_field_training_detail": dataset_config.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
Expand All @@ -487,15 +525,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"roles_to_train": dataset_config.get("roles_to_train", ["assistant"]),
"train_on_eos": dataset_config.get("train_on_eos", "turn"),
}

strategy = ChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)

if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
if "field_messages" in dataset_config and hasattr(strategy, "messages"):
strategy.messages = dataset_config["field_messages"]

return strategy
Loading
Loading