Skip to content

Commit 5501972

Browse files
[Fix] Prompt studio Coverage (#907)
* fixes for prompt sudio coverage * fixed prompt studio local variable issue * updated return type * added optional chaining * added missing type * added comment for error suppression * handled edge cases
1 parent 84497ff commit 5501972

File tree

6 files changed

+110
-83
lines changed

6 files changed

+110
-83
lines changed

backend/prompt_studio/prompt_studio_core_v2/serializers.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,48 +44,71 @@ class Meta:
4444

4545
def to_representation(self, instance): # type: ignore
4646
data = super().to_representation(instance)
47+
default_profile = None
48+
49+
# Fetch summarize LLM profile
4750
try:
48-
profile_manager = ProfileManager.objects.get(
51+
summarize_profile = ProfileManager.objects.get(
4952
prompt_studio_tool=instance, is_summarize_llm=True
5053
)
51-
data[TSKeys.SUMMARIZE_LLM_PROFILE] = profile_manager.profile_id
54+
data[TSKeys.SUMMARIZE_LLM_PROFILE] = summarize_profile.profile_id
5255
except ObjectDoesNotExist:
5356
logger.info(
54-
"Summarize LLM profile doesnt exist for prompt tool %s",
57+
"Summarize LLM profile doesn't exist for prompt tool %s",
5558
str(instance.tool_id),
5659
)
60+
61+
# Fetch default LLM profile
5762
try:
58-
profile_manager = ProfileManager.get_default_llm_profile(instance)
59-
data[TSKeys.DEFAULT_PROFILE] = profile_manager.profile_id
63+
default_profile = ProfileManager.get_default_llm_profile(instance)
64+
data[TSKeys.DEFAULT_PROFILE] = default_profile.profile_id
6065
except DefaultProfileError:
66+
# To make it compatible with older projects error suppressed with warning.
6167
logger.warning(
62-
"Default LLM profile doesnt exist for prompt tool %s",
68+
"Default LLM profile doesn't exist for prompt tool %s",
6369
str(instance.tool_id),
6470
)
65-
prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.filter(
71+
72+
# Fetch prompt instances
73+
prompt_instances: ToolStudioPrompt = ToolStudioPrompt.objects.filter(
6674
tool_id=data.get(TSKeys.TOOL_ID)
6775
).order_by("sequence_number")
68-
data[TSKeys.PROMPTS] = []
76+
77+
if not prompt_instances.exists():
78+
data[TSKeys.PROMPTS] = []
79+
return data
80+
81+
# Process prompt instances
6982
output: list[Any] = []
70-
# Appending prompt instances of the tool for FE Processing
71-
if prompt_instance.count() != 0:
72-
for prompt in prompt_instance:
73-
profile_manager_id = prompt.prompt_id
74-
if instance.single_pass_extraction_mode:
75-
# use projects default profile
76-
profile_manager_id = profile_manager.profile_id
77-
prompt_serializer = ToolStudioPromptSerializer(prompt)
83+
for prompt in prompt_instances:
84+
prompt_serializer = ToolStudioPromptSerializer(prompt)
85+
serialized_data = prompt_serializer.data
86+
87+
# Determine coverage
88+
coverage: list[Any] = []
89+
profile_manager_id = prompt.profile_manager
90+
if default_profile and instance.single_pass_extraction_mode:
91+
profile_manager_id = default_profile.profile_id
92+
93+
if profile_manager_id:
7894
coverage = OutputManagerUtils.get_coverage(
7995
data.get(TSKeys.TOOL_ID),
8096
profile_manager_id,
8197
prompt.prompt_id,
8298
instance.single_pass_extraction_mode,
8399
)
84-
serialized_data = prompt_serializer.data
85-
serialized_data["coverage"] = coverage
86-
output.append(serialized_data)
87-
data[TSKeys.PROMPTS] = output
100+
else:
101+
logger.info(
102+
"Skipping coverage calculation for prompt %s "
103+
"due to missing profile ID",
104+
str(prompt.prompt_key),
105+
)
106+
107+
# Add coverage to serialized data
108+
serialized_data["coverage"] = coverage
109+
output.append(serialized_data)
88110

111+
data[TSKeys.PROMPTS] = output
89112
data["created_by_email"] = instance.created_by.email
90113

91114
return data
Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from django.db.models import Count
21
from prompt_studio.prompt_studio_output_manager_v2.models import (
32
PromptStudioOutputManager,
43
)
@@ -11,41 +10,33 @@ def get_coverage(
1110
profile_manager_id: str,
1211
prompt_id: str = None,
1312
is_single_pass: bool = False,
14-
) -> dict[str, int]:
13+
) -> list[str]:
1514
"""
1615
Method to fetch coverage data for given tool and profile manager.
1716
1817
Args:
19-
tool (CustomTool): The tool instance or ID for which coverage is fetched.
18+
tool_id (str): The ID of the tool for which coverage is fetched.
2019
profile_manager_id (str): The ID of the profile manager
2120
for which coverage is calculated.
2221
prompt_id (Optional[str]): The ID of the prompt (optional).
23-
is_single_pass (Optional[bool]): Singlepass enabled or not
22+
is_single_pass (Optional[bool]): Singlepass enabled or not.
2423
If provided, coverage is fetched for the specific prompt.
2524
2625
Returns:
27-
dict[str, int]: A dictionary containing coverage information.
26+
dict[str, list[str]]: A dictionary containing coverage information.
2827
Keys are formatted as "coverage_<prompt_id>_<profile_manager_id>".
29-
Values are the count of documents associated with each prompt
28+
Values are lists of document IDs associated with each prompt
3029
and profile combination.
3130
"""
3231
# TODO: remove singlepass reference
33-
prompt_outputs = (
34-
PromptStudioOutputManager.objects.filter(
35-
tool_id=tool_id,
36-
profile_manager_id=profile_manager_id,
37-
prompt_id=prompt_id,
38-
is_single_pass_extract=is_single_pass,
39-
)
40-
.values("prompt_id", "profile_manager_id")
41-
.annotate(document_count=Count("document_manager_id"))
42-
)
32+
prompt_outputs = PromptStudioOutputManager.objects.filter(
33+
tool_id=tool_id,
34+
profile_manager_id=profile_manager_id,
35+
prompt_id=prompt_id,
36+
is_single_pass_extract=is_single_pass,
37+
).values("prompt_id", "profile_manager_id", "document_manager_id")
4338

44-
coverage = {}
39+
coverage = []
4540
for prompt_output in prompt_outputs:
46-
prompt_key = str(prompt_output["prompt_id"])
47-
profile_key = str(prompt_output["profile_manager_id"])
48-
coverage[f"coverage_{prompt_key}_{profile_key}"] = prompt_output[
49-
"document_count"
50-
]
41+
coverage.append(str(prompt_output["document_manager_id"]))
5142
return coverage

frontend/src/components/custom-tools/document-parser/DocumentParser.jsx

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -180,29 +180,6 @@ function DocumentParser({
180180
return outputs;
181181
};
182182

183-
const getPromptCoverageCount = (promptId) => {
184-
const keys = Object.keys(promptOutputs || {});
185-
const coverageKey = `coverage_${promptId}`;
186-
const outputs = {};
187-
if (!keys?.length) {
188-
details?.prompts?.forEach((prompt) => {
189-
if (prompt?.coverage) {
190-
const key = Object.keys(prompt?.coverage)[0];
191-
if (key?.startsWith(coverageKey)) {
192-
outputs[key] = prompt?.coverage[key];
193-
}
194-
}
195-
});
196-
return outputs;
197-
}
198-
keys?.forEach((key) => {
199-
if (key?.startsWith(coverageKey)) {
200-
outputs[key] = promptOutputs[key];
201-
}
202-
});
203-
return outputs;
204-
};
205-
206183
if (!details?.prompts?.length) {
207184
if (isSimplePromptStudio && SpsPromptsEmptyState) {
208185
return <SpsPromptsEmptyState />;
@@ -230,7 +207,7 @@ function DocumentParser({
230207
outputs={getPromptOutputs(item?.prompt_id)}
231208
enforceTypeList={enforceTypeList}
232209
setUpdatedPromptsCopy={setUpdatedPromptsCopy}
233-
coverageCountData={getPromptCoverageCount(item?.prompt_id)}
210+
coverageCountData={item?.coverage}
234211
isChallenge={isChallenge}
235212
/>
236213
<div ref={bottomRef} className="doc-parser-pad-bottom" />

frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import SpaceWrapper from "../../widgets/space-wrapper/SpaceWrapper";
3333
import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader";
3434
import "./ManageDocsModal.css";
3535
import usePostHogEvents from "../../../hooks/usePostHogEvents";
36+
import { usePromptOutputStore } from "../../../store/prompt-output-store";
3637

3738
let SummarizeStatusTitle = null;
3839
try {
@@ -90,6 +91,7 @@ function ManageDocsModal({
9091
const axiosPrivate = useAxiosPrivate();
9192
const handleException = useExceptionHandler();
9293
const { setPostHogCustomEvent } = usePostHogEvents();
94+
const { promptOutputs, updatePromptOutput } = usePromptOutputStore();
9395

9496
const successIndex = (
9597
<Typography.Text>
@@ -543,21 +545,48 @@ function ManageDocsModal({
543545
);
544546
updateCustomTool({ listOfDocs: newListOfDocs });
545547

546-
if (newListOfDocs?.length === 1 && selectedDoc?.document_id !== docId) {
547-
const doc = newListOfDocs[1];
548+
if (selectedDoc?.document_id === docId) {
549+
const doc = newListOfDocs[0];
548550
handleDocChange(doc);
549551
}
550-
551-
if (docId === selectedDoc?.document_id) {
552-
updateCustomTool({ selectedDoc: "" });
553-
handleUpdateTool({ output: "" });
554-
}
552+
const updatedPromptDetails = removeIdFromCoverage(details, docId);
553+
const updatedPromptOutput = removeIdFromCoverageOfPromptOutput(
554+
promptOutputs,
555+
docId
556+
);
557+
updateCustomTool({ details: updatedPromptDetails });
558+
updatePromptOutput(updatedPromptOutput);
555559
})
556560
.catch((err) => {
557561
setAlertDetails(handleException(err, "Failed to delete"));
558562
});
559563
};
560564

565+
const removeIdFromCoverage = (data, idToRemove) => {
566+
if (data.prompts && Array.isArray(data.prompts)) {
567+
data.prompts.forEach((prompt) => {
568+
if (Array.isArray(prompt.coverage)) {
569+
prompt.coverage = prompt.coverage.filter((id) => id !== idToRemove);
570+
}
571+
});
572+
}
573+
return data; // Return the updated data
574+
};
575+
576+
const removeIdFromCoverageOfPromptOutput = (data, idToRemove) => {
577+
return Object.entries(data).reduce((updatedData, [key, value]) => {
578+
// Create a new object for the current entry
579+
updatedData[key] = {
580+
...value,
581+
// Update the coverage array if it exists
582+
coverage: value?.coverage
583+
? value?.coverage?.filter((id) => id !== idToRemove)
584+
: value?.coverage,
585+
};
586+
return updatedData;
587+
}, {});
588+
};
589+
561590
return (
562591
<Modal
563592
className="pre-post-amble-modal"

frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import { Header } from "./Header";
1919
import { OutputForIndex } from "./OutputForIndex";
2020
import { PromptOutput } from "./PromptOutput";
2121
import { TABLE_ENFORCE_TYPE, RECORD_ENFORCE_TYPE } from "./constants";
22-
import { generateCoverageKey } from "../../../helpers/GetStaticData";
22+
import usePromptOutput from "../../../hooks/usePromptOutput";
2323

2424
let TableExtractionSettingsBtn;
2525
try {
@@ -66,6 +66,8 @@ function PromptCardItems({
6666
defaultLlmProfile,
6767
singlePassExtractMode,
6868
} = useCustomToolStore();
69+
70+
const { generatePromptOutputKey } = usePromptOutput();
6971
const [isEditingPrompt, setIsEditingPrompt] = useState(false);
7072
const [isEditingTitle, setIsEditingTitle] = useState(false);
7173
const [expandCard, setExpandCard] = useState(true);
@@ -78,10 +80,18 @@ function PromptCardItems({
7880
const isNotSingleLlmProfile = llmProfiles.length > 1;
7981
const divRef = useRef(null);
8082
const [enforceType, setEnforceType] = useState("");
81-
const profileId = singlePassExtractMode
82-
? defaultLlmProfile
83-
: selectedLlmProfileId || defaultLlmProfile;
84-
const coverageKey = generateCoverageKey(promptDetails?.prompt_id, profileId);
83+
const promptId = promptDetails?.prompt_id;
84+
const docId = selectedDoc?.document_id;
85+
const promptProfile = promptDetails?.profile_manager || defaultLlmProfile;
86+
const promptOutputKey = generatePromptOutputKey(
87+
promptId,
88+
docId,
89+
promptProfile,
90+
singlePassExtractMode,
91+
true
92+
);
93+
const promptCoverage =
94+
promptOutputs[promptOutputKey]?.coverage || coverageCountData;
8595

8696
useEffect(() => {
8797
if (enforceType !== promptDetails?.enforce_type) {
@@ -213,7 +223,7 @@ function PromptCardItems({
213223
<SearchOutlined className="font-size-12" />
214224
)}
215225
<Typography.Link className="font-size-12">
216-
Coverage: {coverageCountData[coverageKey] || 0} of{" "}
226+
Coverage: {promptCoverage?.length || 0} of{" "}
217227
{listOfDocs?.length || 0} docs
218228
</Typography.Link>
219229
</Space>

frontend/src/hooks/usePromptOutput.js

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ const usePromptOutput = () => {
9191

9292
let isTokenUsageForSinglePassAdded = false;
9393
const tokenUsageDetails = {};
94-
9594
data.forEach((item) => {
9695
const promptId = item?.prompt_id;
9796
const docId = item?.document_manager;
@@ -109,7 +108,6 @@ const usePromptOutput = () => {
109108
isSinglePass,
110109
true
111110
);
112-
const coverageKey = `coverage_${item?.prompt_id}_${llmProfile}`;
113111
outputs[key] = {
114112
runId: item?.run_id,
115113
promptOutputId: item?.prompt_output_id,
@@ -119,8 +117,8 @@ const usePromptOutput = () => {
119117
tokenUsage: item?.token_usage,
120118
output: item?.output,
121119
timer,
120+
coverage: item?.coverage,
122121
};
123-
outputs[coverageKey] = item?.coverage[coverageKey] || 0;
124122

125123
if (item?.is_single_pass_extract && isTokenUsageForSinglePassAdded)
126124
return;
@@ -150,7 +148,6 @@ const usePromptOutput = () => {
150148
);
151149
tokenUsageDetails[tokenUsageId] = item?.token_usage;
152150
});
153-
154151
if (isReset) {
155152
setPromptOutput(outputs);
156153
setTokenUsage(tokenUsageDetails);

0 commit comments

Comments
 (0)