diff --git a/langdspy/field_descriptors.py b/langdspy/field_descriptors.py index 6781dd6..5715ec0 100644 --- a/langdspy/field_descriptors.py +++ b/langdspy/field_descriptors.py @@ -39,17 +39,13 @@ def validate_value(self, input: Input, value: Any) -> bool: class HintField(FieldDescriptor): HINT_TOKEN_OPENAI = "💡" HINT_TOKEN_ANTHROPIC = None - def __init__(self, desc: str, formatter: Optional[Callable[[Any], Any]] = None, transformer: Optional[Callable[[Any], Any]] = None, validator: Optional[Callable[[Any], Any]] = None, **kwargs): # Provide a default value for the name parameter, such as an empty string super().__init__("", desc, formatter, transformer, validator, **kwargs) - def _start_format_openai(self): return f"{self.HINT_TOKEN_OPENAI}" - def _start_format_anthropic(self): return f"" - def format_prompt_description(self, llm_type: str): if llm_type == "openai": return f"{self._start_format_openai()} {self.desc}" @@ -59,19 +55,15 @@ def format_prompt_description(self, llm_type: str): class InputField(FieldDescriptor): START_TOKEN_OPENAI = "✅" START_TOKEN_ANTHROPIC = None - def _start_format_openai(self): return f"{self.START_TOKEN_OPENAI}{self.name}" - def _start_format_anthropic(self): return f"<{self.name}>" - def format_prompt_description(self, llm_type: str): if llm_type == "openai": return f"{self._start_format_openai()}: {self.desc}" elif llm_type == "anthropic": - return f"{self._start_format_anthropic()}: {self.desc}" - + return f"{self._start_format_anthropic()}{self.desc}" def format_prompt_value(self, value, llm_type: str): value = self.format_value(value) if llm_type == "openai": @@ -84,11 +76,12 @@ def format_prompt_description(self, llm_type: str): if llm_type == "openai": return f"{self._start_format_openai()}: {self.desc}" elif llm_type == "anthropic": - return f"{self._start_format_anthropic()}: {self.desc}" - + return f"{self._start_format_anthropic()}{self.desc}" def format_prompt_value(self, value, llm_type: str): res = "" if len(value) >= 1: + if llm_type == "anthropic": + res += f"<{self.name}>\n" for i, value in enumerate(value): if i > 0: res += "\n" @@ -97,37 +90,33 @@ def format_prompt_value(self, value, llm_type: str): res += f"{self.START_TOKEN_OPENAI} [{i}]: {value}" elif llm_type == "anthropic": res += f"{value}" + if llm_type == "anthropic": + res += f"\n" else: if llm_type == "openai": res += f"{self._start_format_openai()}: NO VALUES SPECIFIED" elif llm_type == "anthropic": res += f"{self._start_format_anthropic()}NO VALUES SPECIFIED" - return res class OutputField(FieldDescriptor): START_TOKEN_OPENAI = "🔑" START_TOKEN_ANTHROPIC = None - def _start_format_openai(self): return f"{self.START_TOKEN_OPENAI}{self.name}" - def _start_format_anthropic(self): return f"<{self.name}>" - def format_prompt_description(self, llm_type: str): if llm_type == "openai": return f"{self._start_format_openai()}: {self.desc}" elif llm_type == "anthropic": - return f"{self._start_format_anthropic()}: {self.desc}" - + return f"{self._start_format_anthropic()}{self.desc}" def format_prompt_value(self, value, llm_type: str): value = self.format_value(value) if llm_type == "openai": return f"{self._start_format_openai()}: {value}" elif llm_type == "anthropic": return f"{self._start_format_anthropic()}{value}" - def format_prompt(self, llm_type: str): if llm_type == "openai": return f"{self._start_format_openai()}:" @@ -143,11 +132,10 @@ def __init__(self, name: str, desc: str, enum: Enum, **kwargs): kwargs['validator'] = validators.is_one_of kwargs['choices'] = [e.name for e in enum] super().__init__(name, desc, **kwargs) - def format_prompt_description(self, llm_type: str): enum = self.kwargs.get('enum') choices_str = ", ".join([e.name for e in enum]) if llm_type == "openai": return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}" elif llm_type == "anthropic": - return f"{self._start_format_anthropic()}: One of: {choices_str} - {self.desc}" \ No newline at end of file + return f"{self._start_format_anthropic()}One of: {choices_str} - {self.desc}" \ No newline at end of file diff --git a/tests/test_prompt_formatting.py b/tests/test_prompt_formatting.py index 36c802a..7dfccbe 100644 --- a/tests/test_prompt_formatting.py +++ b/tests/test_prompt_formatting.py @@ -29,8 +29,8 @@ def test_format_prompt_anthropic(): formatted_prompt = prompt_runner.template._format_anthropic_prompt(trained_state=None, use_training=True, input="test input") assert "Hint field" in formatted_prompt - assert ": Input field" in formatted_prompt - assert ": Output field" in formatted_prompt + assert "Input field" in formatted_prompt + assert "Output field" in formatted_prompt assert "test input" in formatted_prompt assert "" in formatted_prompt