Skip to content

Commit

Permalink
fix ChatMessageChunk concat error (langchain-ai#10174)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->

- Description: fix `ChatMessageChunk` concat error 
- Issue: langchain-ai#10173 
- Dependencies: None
- Tag maintainer: @baskaryan, @eyurtsev, @rlancemartin
- Twitter handle: None

---------

Co-authored-by: wangshuai.scotty <[email protected]>
Co-authored-by: Nuno Campos <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2023
1 parent 4322b24 commit 88a0207
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 4 deletions.
59 changes: 56 additions & 3 deletions libs/langchain/langchain/schema/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk

if isinstance(self, ChatMessageChunk):
return self.__class__(
role=self.role,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return self.__class__(
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
Expand Down Expand Up @@ -168,7 +176,22 @@ def type(self) -> str:
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""

pass
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)

return self.__class__(
example=self.example,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)

return super().__add__(other)


class SystemMessage(BaseMessage):
Expand Down Expand Up @@ -203,7 +226,22 @@ def type(self) -> str:
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""

pass
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
raise ValueError(
"Cannot concatenate FunctionMessageChunks with different names."
)

return self.__class__(
name=self.name,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)

return super().__add__(other)


class ChatMessage(BaseMessage):
Expand All @@ -221,7 +259,22 @@ def type(self) -> str:
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""

pass
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
raise ValueError(
"Cannot concatenate ChatMessageChunks with different roles."
)

return self.__class__(
role=self.role,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)

return super().__add__(other)


def _message_to_dict(message: BaseMessage) -> dict:
Expand Down
60 changes: 59 additions & 1 deletion libs/langchain/tests/unit_tests/schema/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
import pytest

from langchain.schema.messages import (
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
)


def test_message_chunks() -> None:
Expand Down Expand Up @@ -36,3 +43,54 @@ def test_message_chunks() -> None:
}
},
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501


def test_chat_message_chunks() -> None:
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
role="User", content=" indeed."
) == ChatMessageChunk(
role="User", content="I am indeed."
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"

with pytest.raises(ValueError):
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
role="Assistant", content=" indeed."
)

assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk(
content=" indeed."
) == ChatMessageChunk(
role="User", content="I am indeed."
), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501

assert AIMessageChunk(content="I am") + ChatMessageChunk(
role="User", content=" indeed."
) == AIMessageChunk(
content="I am indeed."
), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501


def test_function_message_chunks() -> None:
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="hello", content=" indeed."
) == FunctionMessageChunk(
name="hello", content="I am indeed."
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"

with pytest.raises(ValueError):
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="bye", content=" indeed."
)


def test_ani_message_chunks() -> None:
assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=True, content=" indeed."
) == AIMessageChunk(
example=True, content="I am indeed."
), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"

with pytest.raises(ValueError):
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=False, content=" indeed."
)

0 comments on commit 88a0207

Please sign in to comment.