Skip to content

Commit ae45e2a

Browse files
harini-venkataramanpre-commit-ci[bot]chandrasekharan-zipstacktahierhussain
authored
[FEAT] Support for table extraction in Prompt studio (#564)
* Support for table extraction in Prompt studio * Refactors-Table extraction * Sonar issues * Removing plugins * Adding missed urls * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enrich error message Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: harini-venkataraman <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Warn cloud exceptions Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: harini-venkataraman <[email protected]> * Update prompt-service/src/unstract/prompt_service/main.py Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: harini-venkataraman <[email protected]> * Skip table for single pass * Adding method alias * Refactors * Method refactors * Constants * FEAT: Table Extractor (FE) (#590) * FE changes for table extraction * UI improvements related to table extraction * Fixed Eslint issues * Passed the output prop * Fixed component rendering and sonar issue * Added optional chaining * Output in payload * Sonar fixes * Sonar fixes * Bump structure tool version * Remove metadata includions * Add TODO for plugging logic to tools --------- Signed-off-by: harini-venkataraman <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chandrasekharan M <[email protected]> Co-authored-by: Tahier Hussain <[email protected]>
1 parent 8d21cb1 commit ae45e2a

File tree

20 files changed

+799
-348
lines changed

20 files changed

+799
-348
lines changed

backend/backend/urls.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,17 @@
118118
# Clone urls
119119
try:
120120

121-
import pluggable_apps.clone.urls # noqa # pylint: disable=unused-import
122-
123121
urlpatterns += [
124122
path("", include("pluggable_apps.clone.urls")),
125123
]
126124
except ImportError:
127125
pass
126+
127+
try:
128+
import pluggable_apps.apps.table_settings # noqa # pylint: disable=unused-import
129+
130+
urlpatterns += [
131+
path("", include("pluggable_apps.apps.table_settings.urls")),
132+
]
133+
except ImportError:
134+
pass
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import logging
2+
import os
3+
from importlib import import_module
4+
from typing import Any
5+
6+
from django.apps import apps
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class ModifierConfig:
12+
"""Loader config for extraction plugins."""
13+
14+
PLUGINS_APP = "plugins"
15+
PLUGIN_DIR = "modifier"
16+
MODULE = "module"
17+
METADATA = "metadata"
18+
METADATA_NAME = "name"
19+
METADATA_SERVICE_CLASS = "service_class"
20+
METADATA_IS_ACTIVE = "is_active"
21+
22+
23+
def load_plugins() -> list[Any]:
24+
"""Iterate through the extraction plugins and register them."""
25+
plugins_app = apps.get_app_config(ModifierConfig.PLUGINS_APP)
26+
package_path = plugins_app.module.__package__
27+
modifier_dir = os.path.join(plugins_app.path, ModifierConfig.PLUGIN_DIR)
28+
modifier_package_path = f"{package_path}.{ModifierConfig.PLUGIN_DIR}"
29+
modifier_plugins: list[Any] = []
30+
31+
if not os.path.exists(modifier_dir):
32+
return modifier_plugins
33+
34+
for item in os.listdir(modifier_dir):
35+
# Loads a plugin if it is in a directory.
36+
if os.path.isdir(os.path.join(modifier_dir, item)):
37+
modifier_module_name = item
38+
# Loads a plugin if it is a shared library.
39+
# Module name is extracted from shared library name.
40+
elif item.endswith(".so"):
41+
modifier_module_name = item.split(".")[0]
42+
else:
43+
continue
44+
try:
45+
full_module_path = f"{modifier_package_path}.{modifier_module_name}"
46+
module = import_module(full_module_path)
47+
metadata = getattr(module, ModifierConfig.METADATA, {})
48+
49+
if metadata.get(ModifierConfig.METADATA_IS_ACTIVE, False):
50+
modifier_plugins.append(
51+
{
52+
ModifierConfig.MODULE: module,
53+
ModifierConfig.METADATA: module.metadata,
54+
}
55+
)
56+
logger.info(
57+
"Loaded modifier plugin: %s, is_active: %s",
58+
module.metadata[ModifierConfig.METADATA_NAME],
59+
module.metadata[ModifierConfig.METADATA_IS_ACTIVE],
60+
)
61+
else:
62+
logger.info(
63+
"modifier plugin %s is not active.",
64+
modifier_module_name,
65+
)
66+
except ModuleNotFoundError:
67+
logger.warning("No prompt modifier plugins loaded")
68+
69+
if len(modifier_plugins) == 0:
70+
logger.info("No modifier plugins found.")
71+
72+
return modifier_plugins
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Generated by Django 4.2.1 on 2024-08-07 14:20
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
("prompt_studio", "0006_alter_toolstudioprompt_prompt_key_and_more"),
10+
]
11+
12+
operations = [
13+
migrations.AlterField(
14+
model_name="toolstudioprompt",
15+
name="enforce_type",
16+
field=models.TextField(
17+
blank=True,
18+
choices=[
19+
("Text", "Response sent as Text"),
20+
("number", "Response sent as number"),
21+
("email", "Response sent as email"),
22+
("date", "Response sent as date"),
23+
("boolean", "Response sent as boolean"),
24+
("json", "Response sent as json"),
25+
("table", "Response sent as table"),
26+
],
27+
db_comment="Field to store the type in which the response to be returned.",
28+
default="Text",
29+
),
30+
),
31+
]

backend/prompt_studio/prompt_studio/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class EnforceType(models.TextChoices):
2020
DATE = "date", "Response sent as date"
2121
BOOLEAN = "boolean", "Response sent as boolean"
2222
JSON = "json", "Response sent as json"
23+
TABLE = "table", "Response sent as table"
2324

2425
class PromptType(models.TextChoices):
2526
PROMPT = "PROMPT", "Response sent as Text"

backend/prompt_studio/prompt_studio_core/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class ToolStudioPromptKeys:
8888
CONTEXT = "context"
8989
METADATA = "metadata"
9090
INCLUDE_METADATA = "include_metadata"
91+
TXT_EXTENTION = ".txt"
92+
TABLE = "table"
93+
EXTRACT = "extract"
9194
PLATFORM_POSTAMBLE = "platform_postamble"
9295
SUMMARIZE_AS_SOURCE = "summarize_as_source"
9396

backend/prompt_studio/prompt_studio_core/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,13 @@ class MaxProfilesReachedError(APIException):
6767
f"Maximum number of profiles (max {ProfileManagerKeys.MAX_PROFILE_COUNT})"
6868
" per prompt studio project has been reached."
6969
)
70+
71+
72+
class OperationNotSupported(APIException):
73+
status_code = 403
74+
default_detail = (
75+
"This feature is not supported "
76+
"in the open-source version. "
77+
"Please check our cloud or enterprise on-premise offering "
78+
"for access to this functionality."
79+
)

backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from django.conf import settings
1313
from django.db.models.manager import BaseManager
1414
from file_management.file_management_helper import FileManagerHelper
15+
from prompt_studio.modifier_loader import ModifierConfig
16+
from prompt_studio.modifier_loader import load_plugins as load_modifier_plugins
1517
from prompt_studio.prompt_profile_manager.models import ProfileManager
1618
from prompt_studio.prompt_profile_manager.profile_manager_helper import (
1719
ProfileManagerHelper,
@@ -28,6 +30,7 @@
2830
EmptyPromptError,
2931
IndexingAPIError,
3032
NoPromptsFound,
33+
OperationNotSupported,
3134
PermissionError,
3235
)
3336
from prompt_studio.prompt_studio_core.models import CustomTool
@@ -53,6 +56,8 @@
5356

5457
logger = logging.getLogger(__name__)
5558

59+
modifier_loader = load_modifier_plugins()
60+
5661

5762
class PromptStudioHelper:
5863
"""Helper class for Custom tool operations."""
@@ -434,6 +439,10 @@ def _execute_single_prompt(
434439
text_processor: Optional[type[Any]] = None,
435440
):
436441
prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id)
442+
443+
if prompt_instance.enforce_type == TSPKeys.TABLE and not modifier_loader:
444+
raise OperationNotSupported()
445+
437446
prompt_name = prompt_instance.prompt_key
438447
PromptStudioHelper._publish_log(
439448
{
@@ -523,7 +532,9 @@ def _execute_prompts_in_single_pass(
523532
prompts = [
524533
prompt
525534
for prompt in prompts
526-
if prompt.prompt_type != TSPKeys.NOTES and prompt.active
535+
if prompt.prompt_type != TSPKeys.NOTES
536+
and prompt.active
537+
and prompt.enforce_type != TSPKeys.TABLE
527538
]
528539
if not prompts:
529540
logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}")
@@ -583,6 +594,19 @@ def _get_document_path(org_id, user_id, tool_id, doc_name):
583594
)
584595
return str(Path(doc_path) / doc_name)
585596

597+
@staticmethod
598+
def _get_extract_or_summary_document_path(
599+
org_id, user_id, tool_id, doc_name, doc_type
600+
) -> str:
601+
doc_path = FileManagerHelper.handle_sub_directory_for_tenants(
602+
org_id=org_id,
603+
user_id=user_id,
604+
tool_id=tool_id,
605+
is_create=False,
606+
)
607+
extracted_doc_name = Path(doc_name).stem + TSPKeys.TXT_EXTENTION
608+
return str(Path(doc_path) / doc_type / extracted_doc_name)
609+
586610
@staticmethod
587611
def _handle_response(
588612
response,
@@ -698,7 +722,7 @@ def _fetch_response(
698722
"status": IndexingStatus.PENDING_STATUS.value,
699723
"message": IndexingStatus.DOCUMENT_BEING_INDEXED.value,
700724
}
701-
725+
tool_id = str(tool.tool_id)
702726
output: dict[str, Any] = {}
703727
outputs: list[dict[str, Any]] = []
704728
grammer_dict = {}
@@ -738,6 +762,10 @@ def _fetch_response(
738762
attr_val = getattr(prompt, attr)
739763
output[TSPKeys.EVAL_SETTINGS][attr] = attr_val
740764

765+
output = PromptStudioHelper.fetch_table_settings_if_enabled(
766+
doc_name, prompt, org_id, user_id, tool_id, output
767+
)
768+
741769
outputs.append(output)
742770

743771
tool_settings = {}
@@ -754,8 +782,6 @@ def _fetch_response(
754782
settings, TSPKeys.PLATFORM_POSTAMBLE.upper(), ""
755783
)
756784

757-
tool_id = str(tool.tool_id)
758-
759785
file_hash = ToolUtils.get_hash_from_file(file_path=doc_path)
760786

761787
payload = {
@@ -789,6 +815,36 @@ def _fetch_response(
789815
output_response = json.loads(answer["structure_output"])
790816
return output_response
791817

818+
@staticmethod
819+
def fetch_table_settings_if_enabled(
820+
doc_name: str,
821+
prompt: ToolStudioPrompt,
822+
org_id: str,
823+
user_id: str,
824+
tool_id: str,
825+
output: dict[str, Any],
826+
) -> dict[str, Any]:
827+
828+
if prompt.enforce_type == TSPKeys.TABLE:
829+
extract_doc_path: str = (
830+
PromptStudioHelper._get_extract_or_summary_document_path(
831+
org_id, user_id, tool_id, doc_name, TSPKeys.EXTRACT
832+
)
833+
)
834+
for modifier_plugin in modifier_loader:
835+
cls = modifier_plugin[ModifierConfig.METADATA][
836+
ModifierConfig.METADATA_SERVICE_CLASS
837+
]
838+
output = cls.update(
839+
output=output,
840+
tool_id=tool_id,
841+
prompt_id=str(prompt.prompt_id),
842+
prompt=prompt.prompt,
843+
input_file=extract_doc_path,
844+
)
845+
846+
return output
847+
792848
@staticmethod
793849
def dynamic_indexer(
794850
profile_manager: ProfileManager,

backend/prompt_studio/prompt_studio_core/static/select_choices.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"email":"email",
1414
"date":"date",
1515
"boolean":"boolean",
16-
"json":"json"
16+
"json":"json",
17+
"table":"table"
1718
},
1819
"output_processing":{
1920
"DEFAULT":"Default"

backend/prompt_studio/prompt_studio_registry/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class PromptStudioRegistryKeys:
99
PROMPT_REGISTRY_ID = "prompt_registry_id"
1010
FILE_NAME = "file_name"
1111
UNDEFINED = "undefined"
12+
TABLE = "table"
1213

1314

1415
class PromptStudioRegistryErrors:

backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
from adapter_processor.models import AdapterInstance
66
from django.conf import settings
77
from django.db import IntegrityError
8+
from prompt_studio.modifier_loader import ModifierConfig
9+
from prompt_studio.modifier_loader import load_plugins as load_modifier_plugins
810
from prompt_studio.prompt_profile_manager.models import ProfileManager
911
from prompt_studio.prompt_studio.models import ToolStudioPrompt
1012
from prompt_studio.prompt_studio_core.models import CustomTool
1113
from prompt_studio.prompt_studio_core.prompt_studio_helper import PromptStudioHelper
1214
from prompt_studio.prompt_studio_output_manager.models import PromptStudioOutputManager
1315
from unstract.tool_registry.dto import Properties, Spec, Tool
1416

15-
from .constants import JsonSchemaKey
17+
from .constants import JsonSchemaKey, PromptStudioRegistryKeys
1618
from .exceptions import (
1719
EmptyToolExportError,
1820
InternalError,
@@ -23,6 +25,7 @@
2325
from .serializers import PromptStudioRegistrySerializer
2426

2527
logger = logging.getLogger(__name__)
28+
modifier_loader = load_modifier_plugins()
2629

2730

2831
class PromptStudioRegistryHelper:
@@ -320,6 +323,19 @@ def frame_export_json(
320323
output[JsonSchemaKey.SECTION] = prompt.profile_manager.section
321324
output[JsonSchemaKey.REINDEX] = prompt.profile_manager.reindex
322325
output[JsonSchemaKey.EMBEDDING_SUFFIX] = embedding_suffix
326+
327+
if prompt.enforce_type == PromptStudioRegistryKeys.TABLE:
328+
for modifier_plugin in modifier_loader:
329+
cls = modifier_plugin[ModifierConfig.METADATA][
330+
ModifierConfig.METADATA_SERVICE_CLASS
331+
]
332+
output = cls.update(
333+
output=output,
334+
tool_id=tool.tool_id,
335+
prompt_id=prompt.prompt_id,
336+
prompt=prompt.prompt,
337+
)
338+
323339
outputs.append(output)
324340
output = {}
325341
vector_db = ""

0 commit comments

Comments
 (0)