Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/kosong/base/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,33 @@ class Message(BaseModel):

partial: bool | None = None

def extract_text(self, include_think: bool = False) -> str:
    """
    Return the plain-text portion of this message's content.

    String content is returned unchanged. For a list of content parts, the
    text of every TextPart is concatenated in order with no separator; parts
    that are neither TextPart nor ThinkPart contribute nothing.

    Args:
        include_think: When True, the `think` text of each ThinkPart whose
            `encrypted` attribute is None is included as well. When False,
            every ThinkPart is skipped.

    Returns:
        The concatenated plain text ("" when no part contributes).
    """
    content = self.content
    if isinstance(content, str):
        return content

    pieces: list[str] = []
    for part in content:
        if isinstance(part, TextPart):
            pieces.append(part.text)
            continue
        if not isinstance(part, ThinkPart):
            continue
        # NOTE(review): a reviewer questioned this `encrypted is None` filter,
        # pointing out that `encrypted` is the raw encrypted message in base64
        # and does not mean the `think` field itself is encrypted. The filter
        # is kept as-is here; revisit once that question is settled.
        if include_think and part.encrypted is None:
            pieces.append(part.think)
    return "".join(pieces)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about joining with \n?


@field_serializer("content")
def serialize_content(
self, content: str | list[ContentPart]
Expand Down
40 changes: 39 additions & 1 deletion src/kosong/context/linear.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just moved the whole context module into kosong.contrib because we no longer consider it a core feature.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
from pathlib import Path
from typing import IO, Protocol, runtime_checkable
from typing import IO, Protocol, Any, runtime_checkable

from kosong.base.message import Message

Expand All @@ -22,6 +22,44 @@ def history(self) -> list[Message]:
def token_count(self) -> int:
    """Return the token count reported by the underlying storage."""
    return self._storage.token_count

@property
def statistics(self) -> dict[str, Any]:
    """
    Summarize the context as a dictionary of counters.

    Returns:
        A dict with the current `token_count`, the total `message_count`,
        and one counter for each of the "user", "assistant", "tool" and
        "system" roles. A message with any other role is counted only in
        `message_count`.
    """
    messages = self.history
    # One role per message; each per-role counter is an equality count.
    roles = [message.role for message in messages]
    return {
        "token_count": self.token_count,
        "message_count": len(messages),
        "user_message_count": roles.count("user"),
        "assistant_message_count": roles.count("assistant"),
        "tool_message_count": roles.count("tool"),
        "system_message_count": roles.count("system"),
    }

def extract_texts(self, include_think: bool = False) -> list[str]:
    """
    Collect the plain text of every message in the history.

    Args:
        include_think: Passed through to each message's `extract_text` as
            the `include_think` keyword argument.

    Returns:
        One string per message, in history order (empty list for an empty
        history).
    """
    texts: list[str] = []
    for message in self.history:
        texts.append(message.extract_text(include_think=include_think))
    return texts

async def add_message(self, message: Message) -> None:
    """Append `message` to the context by awaiting the storage's `append_message`."""
    await self._storage.append_message(message)

Expand Down
89 changes: 89 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,92 @@ async def run():
assert f.read() == expected

test_path.unlink()


def test_linear_context_statistics():
    """Statistics are all zero for a fresh context and then count messages per role."""
    ctx = LinearContext(storage=MemoryLinearStorage())
    zeroed = {
        "token_count": 0,
        "message_count": 0,
        "user_message_count": 0,
        "assistant_message_count": 0,
        "tool_message_count": 0,
        "system_message_count": 0,
    }
    assert ctx.statistics == zeroed

    conversation = [
        Message(role="system", content="System prompt"),
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi"),
        Message(role="user", content="How are you?"),
        Message(role="assistant", content="I'm fine"),
        Message(role="tool", content="Result", tool_call_id="123"),
    ]

    async def populate():
        # Messages are added in order, then the statistics are snapshotted.
        for entry in conversation:
            await ctx.add_message(entry)
        return ctx.statistics

    observed = asyncio.run(populate())
    assert observed == {
        **zeroed,
        "message_count": 6,
        "user_message_count": 2,
        "assistant_message_count": 2,
        "tool_message_count": 1,
        "system_message_count": 1,
    }


def test_linear_context_extract_texts():
    """extract_texts yields one string per message and skips image parts."""
    from kosong.base.message import ImageURLPart, TextPart

    ctx = LinearContext(storage=MemoryLinearStorage())
    # A fresh context has no messages, hence no texts.
    assert ctx.extract_texts() == []

    image = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/img.png"))
    conversation = [
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there"),
        Message(
            role="user",
            content=[TextPart(text="What is "), image, TextPart(text="this?")],
        ),
    ]

    async def populate():
        for entry in conversation:
            await ctx.add_message(entry)
        return ctx.extract_texts()

    assert asyncio.run(populate()) == ["Hello", "Hi there", "What is this?"]


def test_linear_context_extract_texts_with_think():
    """ThinkPart text appears in extract_texts only when include_think=True."""
    from kosong.base.message import TextPart, ThinkPart

    ctx = LinearContext(storage=MemoryLinearStorage())
    reply = Message(
        role="assistant",
        content=[
            TextPart(text="Let me think..."),
            ThinkPart(think="I need to consider this carefully."),
            TextPart(text="Here's my answer."),
        ],
    )

    async def populate():
        await ctx.add_message(Message(role="user", content="Hello"))
        await ctx.add_message(reply)
        without_think = ctx.extract_texts(include_think=False)
        with_think = ctx.extract_texts(include_think=True)
        return without_think, with_think

    plain, thoughtful = asyncio.run(populate())
    assert plain == ["Hello", "Let me think...Here's my answer."]
    assert thoughtful == [
        "Hello",
        "Let me think...I need to consider this carefully.Here's my answer.",
    ]
80 changes: 80 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,83 @@ def test_deserialize_from_json_with_content_but_no_tool_calls():
}
message = Message.model_validate(data)
assert message.model_dump(exclude_none=True) == data


def test_extract_text_from_string_content():
    """String content is returned unchanged."""
    greeting = "Hello, world!"
    assert Message(role="user", content=greeting).extract_text() == greeting


def test_extract_text_from_text_part():
    """A single TextPart yields exactly its text."""
    only_part = TextPart(text="Hello, world!")
    extracted = Message(role="user", content=[only_part]).extract_text()
    assert extracted == "Hello, world!"


def test_extract_text_ignores_non_text_parts():
    """Image and audio parts contribute nothing to the extracted text."""
    image = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
    audio = AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3"))
    parts = [TextPart(text="Hello"), image, TextPart(text="World"), audio]

    extracted = Message(role="user", content=parts).extract_text()

    assert extracted == "HelloWorld"


def test_extract_text_with_think_part_excluded():
    """With include_think=False a ThinkPart is left out of the result."""
    parts = [
        TextPart(text="Hello"),
        ThinkPart(think="I'm thinking..."),
        TextPart(text="World"),
    ]
    reply = Message(role="assistant", content=parts)

    assert reply.extract_text(include_think=False) == "HelloWorld"


def test_extract_text_with_think_part_included():
    """With include_think=True the ThinkPart text is spliced in, in order."""
    parts = [
        TextPart(text="Hello"),
        ThinkPart(think="I'm thinking..."),
        TextPart(text="World"),
    ]
    reply = Message(role="assistant", content=parts)

    assert reply.extract_text(include_think=True) == "HelloI'm thinking...World"


def test_extract_text_excludes_encrypted_think_part():
    """A ThinkPart with `encrypted` set is skipped even when include_think=True."""
    signed_thought = ThinkPart(think="I'm thinking...", encrypted="signature")
    reply = Message(
        role="assistant",
        content=[TextPart(text="Hello"), signed_thought, TextPart(text="World")],
    )

    # Only the two TextParts survive; the ThinkPart with `encrypted` set is dropped.
    extracted = reply.extract_text(include_think=True)
    assert extracted == "HelloWorld"


def test_extract_text_from_empty_content():
    """Both an empty string and an empty part list extract to ""."""
    for empty_content in ("", []):
        blank = Message(role="user", content=empty_content)
        assert blank.extract_text() == ""


def test_extract_text_multiple_text_parts():
    """Consecutive TextParts are concatenated with no separator."""
    parts = [TextPart(text=word) for word in ("First", "Second", "Third")]
    extracted = Message(role="user", content=parts).extract_text()
    assert extracted == "FirstSecondThird"