Skip to content

Commit

Permalink
[Fix] Prompt studio Coverage (#907)
Browse files Browse the repository at this point in the history
* fixes for prompt sudio coverage

* fixed prompt studio local variable issue

* updated return type

* added optional chaining

* added missing type

* added comment for error suppression

* handled edge cases
  • Loading branch information
jagadeeswaran-zipstack authored Dec 19, 2024
1 parent 84497ff commit 5501972
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 83 deletions.
63 changes: 43 additions & 20 deletions backend/prompt_studio/prompt_studio_core_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,48 +44,71 @@ class Meta:

def to_representation(self, instance): # type: ignore
data = super().to_representation(instance)
default_profile = None

# Fetch summarize LLM profile
try:
profile_manager = ProfileManager.objects.get(
summarize_profile = ProfileManager.objects.get(
prompt_studio_tool=instance, is_summarize_llm=True
)
data[TSKeys.SUMMARIZE_LLM_PROFILE] = profile_manager.profile_id
data[TSKeys.SUMMARIZE_LLM_PROFILE] = summarize_profile.profile_id
except ObjectDoesNotExist:
logger.info(
"Summarize LLM profile doesnt exist for prompt tool %s",
"Summarize LLM profile doesn't exist for prompt tool %s",
str(instance.tool_id),
)

# Fetch default LLM profile
try:
profile_manager = ProfileManager.get_default_llm_profile(instance)
data[TSKeys.DEFAULT_PROFILE] = profile_manager.profile_id
default_profile = ProfileManager.get_default_llm_profile(instance)
data[TSKeys.DEFAULT_PROFILE] = default_profile.profile_id
except DefaultProfileError:
# To make it compatible with older projects error suppressed with warning.
logger.warning(
"Default LLM profile doesnt exist for prompt tool %s",
"Default LLM profile doesn't exist for prompt tool %s",
str(instance.tool_id),
)
prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.filter(

# Fetch prompt instances
prompt_instances: ToolStudioPrompt = ToolStudioPrompt.objects.filter(
tool_id=data.get(TSKeys.TOOL_ID)
).order_by("sequence_number")
data[TSKeys.PROMPTS] = []

if not prompt_instances.exists():
data[TSKeys.PROMPTS] = []
return data

# Process prompt instances
output: list[Any] = []
# Appending prompt instances of the tool for FE Processing
if prompt_instance.count() != 0:
for prompt in prompt_instance:
profile_manager_id = prompt.prompt_id
if instance.single_pass_extraction_mode:
# use projects default profile
profile_manager_id = profile_manager.profile_id
prompt_serializer = ToolStudioPromptSerializer(prompt)
for prompt in prompt_instances:
prompt_serializer = ToolStudioPromptSerializer(prompt)
serialized_data = prompt_serializer.data

# Determine coverage
coverage: list[Any] = []
profile_manager_id = prompt.profile_manager
if default_profile and instance.single_pass_extraction_mode:
profile_manager_id = default_profile.profile_id

if profile_manager_id:
coverage = OutputManagerUtils.get_coverage(
data.get(TSKeys.TOOL_ID),
profile_manager_id,
prompt.prompt_id,
instance.single_pass_extraction_mode,
)
serialized_data = prompt_serializer.data
serialized_data["coverage"] = coverage
output.append(serialized_data)
data[TSKeys.PROMPTS] = output
else:
logger.info(
"Skipping coverage calculation for prompt %s "
"due to missing profile ID",
str(prompt.prompt_key),
)

# Add coverage to serialized data
serialized_data["coverage"] = coverage
output.append(serialized_data)

data[TSKeys.PROMPTS] = output
data["created_by_email"] = instance.created_by.email

return data
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from django.db.models import Count
from prompt_studio.prompt_studio_output_manager_v2.models import (
PromptStudioOutputManager,
)
Expand All @@ -11,41 +10,33 @@ def get_coverage(
profile_manager_id: str,
prompt_id: str = None,
is_single_pass: bool = False,
) -> dict[str, int]:
) -> list[str]:
"""
Method to fetch coverage data for given tool and profile manager.
Args:
tool (CustomTool): The tool instance or ID for which coverage is fetched.
tool_id (str): The ID of the tool for which coverage is fetched.
profile_manager_id (str): The ID of the profile manager
for which coverage is calculated.
prompt_id (Optional[str]): The ID of the prompt (optional).
is_single_pass (Optional[bool]): Singlepass enabled or not
is_single_pass (Optional[bool]): Singlepass enabled or not.
If provided, coverage is fetched for the specific prompt.
Returns:
dict[str, int]: A dictionary containing coverage information.
dict[str, list[str]]: A dictionary containing coverage information.
Keys are formatted as "coverage_<prompt_id>_<profile_manager_id>".
Values are the count of documents associated with each prompt
Values are lists of document IDs associated with each prompt
and profile combination.
"""
# TODO: remove singlepass reference
prompt_outputs = (
PromptStudioOutputManager.objects.filter(
tool_id=tool_id,
profile_manager_id=profile_manager_id,
prompt_id=prompt_id,
is_single_pass_extract=is_single_pass,
)
.values("prompt_id", "profile_manager_id")
.annotate(document_count=Count("document_manager_id"))
)
prompt_outputs = PromptStudioOutputManager.objects.filter(
tool_id=tool_id,
profile_manager_id=profile_manager_id,
prompt_id=prompt_id,
is_single_pass_extract=is_single_pass,
).values("prompt_id", "profile_manager_id", "document_manager_id")

coverage = {}
coverage = []
for prompt_output in prompt_outputs:
prompt_key = str(prompt_output["prompt_id"])
profile_key = str(prompt_output["profile_manager_id"])
coverage[f"coverage_{prompt_key}_{profile_key}"] = prompt_output[
"document_count"
]
coverage.append(str(prompt_output["document_manager_id"]))
return coverage
Original file line number Diff line number Diff line change
Expand Up @@ -180,29 +180,6 @@ function DocumentParser({
return outputs;
};

const getPromptCoverageCount = (promptId) => {
const keys = Object.keys(promptOutputs || {});
const coverageKey = `coverage_${promptId}`;
const outputs = {};
if (!keys?.length) {
details?.prompts?.forEach((prompt) => {
if (prompt?.coverage) {
const key = Object.keys(prompt?.coverage)[0];
if (key?.startsWith(coverageKey)) {
outputs[key] = prompt?.coverage[key];
}
}
});
return outputs;
}
keys?.forEach((key) => {
if (key?.startsWith(coverageKey)) {
outputs[key] = promptOutputs[key];
}
});
return outputs;
};

if (!details?.prompts?.length) {
if (isSimplePromptStudio && SpsPromptsEmptyState) {
return <SpsPromptsEmptyState />;
Expand Down Expand Up @@ -230,7 +207,7 @@ function DocumentParser({
outputs={getPromptOutputs(item?.prompt_id)}
enforceTypeList={enforceTypeList}
setUpdatedPromptsCopy={setUpdatedPromptsCopy}
coverageCountData={getPromptCoverageCount(item?.prompt_id)}
coverageCountData={item?.coverage}
isChallenge={isChallenge}
/>
<div ref={bottomRef} className="doc-parser-pad-bottom" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import SpaceWrapper from "../../widgets/space-wrapper/SpaceWrapper";
import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader";
import "./ManageDocsModal.css";
import usePostHogEvents from "../../../hooks/usePostHogEvents";
import { usePromptOutputStore } from "../../../store/prompt-output-store";

let SummarizeStatusTitle = null;
try {
Expand Down Expand Up @@ -90,6 +91,7 @@ function ManageDocsModal({
const axiosPrivate = useAxiosPrivate();
const handleException = useExceptionHandler();
const { setPostHogCustomEvent } = usePostHogEvents();
const { promptOutputs, updatePromptOutput } = usePromptOutputStore();

const successIndex = (
<Typography.Text>
Expand Down Expand Up @@ -543,21 +545,48 @@ function ManageDocsModal({
);
updateCustomTool({ listOfDocs: newListOfDocs });

if (newListOfDocs?.length === 1 && selectedDoc?.document_id !== docId) {
const doc = newListOfDocs[1];
if (selectedDoc?.document_id === docId) {
const doc = newListOfDocs[0];
handleDocChange(doc);
}

if (docId === selectedDoc?.document_id) {
updateCustomTool({ selectedDoc: "" });
handleUpdateTool({ output: "" });
}
const updatedPromptDetails = removeIdFromCoverage(details, docId);
const updatedPromptOutput = removeIdFromCoverageOfPromptOutput(
promptOutputs,
docId
);
updateCustomTool({ details: updatedPromptDetails });
updatePromptOutput(updatedPromptOutput);
})
.catch((err) => {
setAlertDetails(handleException(err, "Failed to delete"));
});
};

const removeIdFromCoverage = (data, idToRemove) => {
if (data.prompts && Array.isArray(data.prompts)) {
data.prompts.forEach((prompt) => {
if (Array.isArray(prompt.coverage)) {
prompt.coverage = prompt.coverage.filter((id) => id !== idToRemove);
}
});
}
return data; // Return the updated data
};

const removeIdFromCoverageOfPromptOutput = (data, idToRemove) => {
return Object.entries(data).reduce((updatedData, [key, value]) => {
// Create a new object for the current entry
updatedData[key] = {
...value,
// Update the coverage array if it exists
coverage: value?.coverage
? value?.coverage?.filter((id) => id !== idToRemove)
: value?.coverage,
};
return updatedData;
}, {});
};

return (
<Modal
className="pre-post-amble-modal"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { Header } from "./Header";
import { OutputForIndex } from "./OutputForIndex";
import { PromptOutput } from "./PromptOutput";
import { TABLE_ENFORCE_TYPE, RECORD_ENFORCE_TYPE } from "./constants";
import { generateCoverageKey } from "../../../helpers/GetStaticData";
import usePromptOutput from "../../../hooks/usePromptOutput";

let TableExtractionSettingsBtn;
try {
Expand Down Expand Up @@ -66,6 +66,8 @@ function PromptCardItems({
defaultLlmProfile,
singlePassExtractMode,
} = useCustomToolStore();

const { generatePromptOutputKey } = usePromptOutput();
const [isEditingPrompt, setIsEditingPrompt] = useState(false);
const [isEditingTitle, setIsEditingTitle] = useState(false);
const [expandCard, setExpandCard] = useState(true);
Expand All @@ -78,10 +80,18 @@ function PromptCardItems({
const isNotSingleLlmProfile = llmProfiles.length > 1;
const divRef = useRef(null);
const [enforceType, setEnforceType] = useState("");
const profileId = singlePassExtractMode
? defaultLlmProfile
: selectedLlmProfileId || defaultLlmProfile;
const coverageKey = generateCoverageKey(promptDetails?.prompt_id, profileId);
const promptId = promptDetails?.prompt_id;
const docId = selectedDoc?.document_id;
const promptProfile = promptDetails?.profile_manager || defaultLlmProfile;
const promptOutputKey = generatePromptOutputKey(
promptId,
docId,
promptProfile,
singlePassExtractMode,
true
);
const promptCoverage =
promptOutputs[promptOutputKey]?.coverage || coverageCountData;

useEffect(() => {
if (enforceType !== promptDetails?.enforce_type) {
Expand Down Expand Up @@ -213,7 +223,7 @@ function PromptCardItems({
<SearchOutlined className="font-size-12" />
)}
<Typography.Link className="font-size-12">
Coverage: {coverageCountData[coverageKey] || 0} of{" "}
Coverage: {promptCoverage?.length || 0} of{" "}
{listOfDocs?.length || 0} docs
</Typography.Link>
</Space>
Expand Down
5 changes: 1 addition & 4 deletions frontend/src/hooks/usePromptOutput.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ const usePromptOutput = () => {

let isTokenUsageForSinglePassAdded = false;
const tokenUsageDetails = {};

data.forEach((item) => {
const promptId = item?.prompt_id;
const docId = item?.document_manager;
Expand All @@ -109,7 +108,6 @@ const usePromptOutput = () => {
isSinglePass,
true
);
const coverageKey = `coverage_${item?.prompt_id}_${llmProfile}`;
outputs[key] = {
runId: item?.run_id,
promptOutputId: item?.prompt_output_id,
Expand All @@ -119,8 +117,8 @@ const usePromptOutput = () => {
tokenUsage: item?.token_usage,
output: item?.output,
timer,
coverage: item?.coverage,
};
outputs[coverageKey] = item?.coverage[coverageKey] || 0;

if (item?.is_single_pass_extract && isTokenUsageForSinglePassAdded)
return;
Expand Down Expand Up @@ -150,7 +148,6 @@ const usePromptOutput = () => {
);
tokenUsageDetails[tokenUsageId] = item?.token_usage;
});

if (isReset) {
setPromptOutput(outputs);
setTokenUsage(tokenUsageDetails);
Expand Down

0 comments on commit 5501972

Please sign in to comment.