Skip to content

Commit 0b34be4

Browse files
authored
refactor(langchain): refactor unit test stub classes (#32209)
See #32098 (comment)
1 parent 6f3169e commit 0b34be4

File tree

1 file changed

+15
-40
lines changed

1 file changed

+15
-40
lines changed
Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,23 @@
1-
from typing import Any
1+
from langchain_core.messages import AIMessage, AIMessageChunk
2+
from pydantic import BaseModel
23

3-
from langchain_core.documents import Document
4-
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
5-
6-
7-
class AnyStr(str):
8-
__slots__ = ()
94

5+
class _AnyIDMixin(BaseModel):
106
def __eq__(self, other: object) -> bool:
11-
return isinstance(other, str)
12-
13-
__hash__ = str.__hash__
14-
15-
16-
# The code below creates version of pydantic models
17-
# that will work in unit tests with AnyStr as id field
18-
# Please note that the `id` field is assigned AFTER the model is created
19-
# to workaround an issue with pydantic ignoring the __eq__ method on
20-
# subclassed strings.
21-
22-
23-
def _AnyIdDocument(**kwargs: Any) -> Document:
24-
"""Create a document with an id field."""
25-
message = Document(**kwargs)
26-
message.id = AnyStr()
27-
return message
28-
7+
if isinstance(other, BaseModel):
8+
dump = self.model_dump()
9+
dump.pop("id")
10+
other_dump = other.model_dump()
11+
other_dump.pop("id")
12+
return dump == other_dump
13+
return False
2914

30-
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
31-
"""Create ai message with an any id field."""
32-
message = AIMessage(**kwargs)
33-
message.id = AnyStr()
34-
return message
15+
__hash__ = None # type: ignore[assignment]
3516

3617

37-
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
38-
"""Create ai message with an any id field."""
39-
message = AIMessageChunk(**kwargs)
40-
message.id = AnyStr()
41-
return message
18+
class _AnyIdAIMessage(AIMessage, _AnyIDMixin):
19+
"""AIMessage with any ID."""
4220

4321

44-
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
45-
"""Create a human with an any id field."""
46-
message = HumanMessage(**kwargs)
47-
message.id = AnyStr()
48-
return message
22+
class _AnyIdAIMessageChunk(AIMessageChunk, _AnyIDMixin):
23+
"""AIMessageChunk with any ID."""

0 commit comments

Comments
 (0)