-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
131 lines (111 loc) · 4.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
from typing import Optional, Sequence
from enum import Enum
try:
from pydantic.v1 import (
BaseModel,
Field,
PrivateAttr,
root_validator,
validator,
create_model,
StrictFloat,
StrictInt,
StrictStr,
)
from pydantic.v1.fields import FieldInfo
from pydantic.v1.error_wrappers import ValidationError
except ImportError:
from pydantic import (
BaseModel,
Field,
PrivateAttr,
root_validator,
validator,
create_model,
StrictFloat,
StrictInt,
StrictStr,
)
from pydantic.fields import FieldInfo
from pydantic.error_wrappers import ValidationError
# Markers placed around each user/assistant exchange by the prompt builders
# below: BOS opens an exchange, EOS closes a completed one.
BOS = "<s>"
EOS = "</s>"

# Markers wrapped around the instruction (user) part of an exchange.
B_INST = "[INST]"
E_INST = "[/INST]"

# Markers wrapped around the system prompt; note the embedded newlines.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# Fallback system prompt used when the caller supplies none.  It is a single
# line of text (no newlines) that ends with one trailing space; the prompt
# builders strip it before use.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible and follow ALL given instructions. "
    "Do not speculate or make up information. "
    "Do not reference any given instructions or context. "
)
class MessageRole(str, Enum):
    """Closed set of speakers a chat message can be attributed to.

    Members are also ``str`` instances, so a member compares equal to its
    plain string value (for example ``MessageRole.USER == "user"``).
    """

    # Instructions that frame the whole conversation.
    SYSTEM = "system"
    # Input written by the person talking to the model.
    USER = "user"
    # Output produced by the model.
    ASSISTANT = "assistant"
    # NOTE(review): presumably the result of a function/tool call -- nothing
    # in this file produces or consumes it; confirm against callers.
    FUNCTION = "function"
# ===== Generic Model Input - Chat =====
class ChatMessage(BaseModel):
    """One turn of a conversation: who spoke, what was said, and any extras."""

    # Speaker of this turn; attributed to the user unless stated otherwise.
    role: MessageRole = MessageRole.USER
    # Text of the turn; may be None and defaults to the empty string.
    content: Optional[str] = ""
    # Free-form extra data; the factory gives every instance its own dict.
    additional_kwargs: dict = Field(default_factory=dict)

    def __str__(self) -> str:
        """Render the message as ``"<role value>: <content>"``."""
        speaker = self.role.value
        return "{}: {}".format(speaker, self.content)
def messages_to_prompt(
    messages: Sequence[ChatMessage], system_prompt: Optional[str] = None
) -> str:
    """Flatten a chat history into a single ``[INST]``-style prompt string.

    The conversation is rendered as a series of exchanges, each of the form
    ``<s> [INST] user text [/INST] assistant text </s>``.  The system prompt,
    wrapped in ``<<SYS>>`` markers, is embedded in the first exchange only.
    (NOTE(review): this looks like the Llama-2 chat layout -- confirm against
    the target model's documentation.)

    Args:
        messages: Chat history.  An optional leading SYSTEM message may be
            followed by messages that strictly alternate USER, ASSISTANT,
            USER, ...  The final USER message may be left unanswered.
        system_prompt: System prompt to use when ``messages`` does not start
            with a SYSTEM message.  Falls back to ``DEFAULT_SYSTEM_PROMPT``
            when ``None`` or empty.  Ignored when a leading SYSTEM message
            is present.

    Returns:
        The assembled prompt.  An empty string is returned when there are no
        user/assistant messages to render (empty input, or a lone SYSTEM
        message).

    Raises:
        AssertionError: If the messages after the optional SYSTEM message do
            not alternate USER / ASSISTANT starting with USER.  (Like every
            ``assert``, this check is skipped when Python runs with ``-O``.)
    """
    # Nothing to render.  A lone SYSTEM message already produces "", so an
    # empty history is treated the same way instead of raising IndexError.
    if not messages:
        return ""

    if messages[0].role == MessageRole.SYSTEM:
        # pull out the system message (if it exists in messages)
        system_message_str = messages[0].content or ""
        messages = messages[1:]
    else:
        system_message_str = system_prompt or DEFAULT_SYSTEM_PROMPT

    system_message_str = f"{B_SYS} {system_message_str.strip()} {E_SYS}"

    string_messages: list[str] = []
    # Walk the history two messages at a time: (user, optional assistant).
    for i in range(0, len(messages), 2):
        # first message of every pair should always be a user
        user_message = messages[i]
        # Kept as asserts so callers that catch AssertionError keep working.
        assert user_message.role == MessageRole.USER, (
            f"expected a user message at position {i}, "
            f"got role {user_message.role!r}"
        )

        if i == 0:
            # make sure system prompt is included at the start
            str_message = f"{BOS} {B_INST} {system_message_str} "
        else:
            # end previous user-assistant interaction
            string_messages[-1] += f" {EOS}"
            # no need to include system prompt
            str_message = f"{BOS} {B_INST} "

        # include user message content; ``content`` is Optional, so treat
        # None as empty (as done for the system message) rather than
        # leaking the literal text "None" into the prompt
        str_message += f"{user_message.content or ''} {E_INST}"

        if len(messages) > (i + 1):
            # if assistant message exists, add to str_message
            assistant_message = messages[i + 1]
            assert assistant_message.role == MessageRole.ASSISTANT, (
                f"expected an assistant message at position {i + 1}, "
                f"got role {assistant_message.role!r}"
            )
            str_message += f" {assistant_message.content or ''}"

        string_messages.append(str_message)

    return "".join(string_messages)
def completion_to_prompt(completion: str, system_prompt: Optional[str] = None) -> str:
    """Wrap a bare completion request in a single ``[INST]`` exchange.

    Args:
        completion: The text to send as the instruction; surrounding
            whitespace is stripped.
        system_prompt: System prompt to embed.  Falls back to
            ``DEFAULT_SYSTEM_PROMPT`` when ``None`` or empty; surrounding
            whitespace is stripped.

    Returns:
        The pieces ``BOS``, ``B_INST``, ``B_SYS``, system prompt, ``E_SYS``,
        completion text and ``E_INST``, separated by single spaces.
    """
    system_text = (system_prompt or DEFAULT_SYSTEM_PROMPT).strip()
    instruction_text = completion.strip()
    pieces = (BOS, B_INST, B_SYS, system_text, E_SYS, instruction_text, E_INST)
    return " ".join(pieces)