Skip to content

Commit 117e118

Browse files
Adding example for using prediction API for Search Index Chunking (#98)
* Adding example for using prediction API
1 parent 2e974af commit 117e118

3 files changed

Lines changed: 289 additions & 0 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
{
2+
"entryPoint": "entrypoint.py"
3+
}
Lines changed: 246 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,246 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) 2025, Salesforce, Inc.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Housing Sale Price Prediction with Einstein Regression
7+
8+
This example uses an Einstein regression model to predict housing sale prices
9+
based on property features like Year_Built__c.
10+
11+
Model: YH_Regression_Python_Predicted_SalePrice_CM_12l_ATC937af934
12+
Type: Regression
13+
Input: source DMO fields (numeric or string), e.g. Year_Built__c
14+
Output: Predicted_SalePrice
15+
"""
16+
17+
import logging
18+
from typing import (
19+
Any,
20+
Dict,
21+
Optional,
22+
)
23+
24+
from datacustomcode.einstein_predictions.types import (
25+
PredictionColumBuilder,
26+
PredictionRequestBuilder,
27+
PredictionType,
28+
)
29+
from datacustomcode.function import Runtime
30+
from datacustomcode.function.feature_types.chunking import (
31+
ChunkType,
32+
SearchIndexChunkingV1Output,
33+
SearchIndexChunkingV1Request,
34+
SearchIndexChunkingV1Response,
35+
)
36+
37+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
38+
# Configures the root logger at INFO level as a side effect of importing this
# module (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
39+
40+
# Configuration
41+
# API name of the Einstein model; predict_sale_price() passes it to
# PredictionRequestBuilder.set_model_api_name().
PREDICTION_MODEL_NAME = "YH_Regression_Python_Predicted_SalePrice_CM_12l_ATC937af934"
42+
43+
44+
def predict_sale_price(
    features: Dict[str, Any],
    runtime: Runtime,
) -> Optional[float]:
    """Predict a housing sale price using the Einstein regression model.

    Each string feature is sent as a single-value string column and each
    int/float feature as a single-value double column.  Features of any other
    type are skipped with a warning.  The request targets the model named by
    ``PREDICTION_MODEL_NAME``.

    Args:
        features: Mapping of column name to feature value (numeric or string).
            An empty or ``None`` mapping yields no prediction.
        runtime: Runtime whose ``einstein_predictions.predict`` is invoked.

    Returns:
        The predicted sale price as a float, or ``None`` when there is nothing
        to send, the call is unsuccessful, the response has no usable
        regression result, or any exception is raised along the way.
    """
    try:
        # Build prediction columns - handle both numeric and string values
        prediction_columns = []

        for column_name, value in (features or {}).items():
            if isinstance(value, str):
                # String values (e.g., Garage_Qual__c)
                column = (
                    PredictionColumBuilder()
                    .set_column_name(column_name)
                    .set_string_values([value])
                    .build()
                )
            elif isinstance(value, (int, float)):
                # Numeric values.  NOTE: bool is a subclass of int, so
                # True/False are sent as 1.0/0.0.
                column = (
                    PredictionColumBuilder()
                    .set_column_name(column_name)
                    .set_double_values([float(value)])
                    .build()
                )
            else:
                # Skip unsupported types
                logger.warning(
                    f"Skipping field {column_name} with unsupported type {type(value)}"
                )
                continue

            prediction_columns.append(column)

        # Without any input column there is nothing for the model to score,
        # so avoid a pointless remote call.
        if not prediction_columns:
            logger.warning("No usable features available for prediction")
            return None

        # Build regression prediction request
        prediction_request = (
            PredictionRequestBuilder()
            .set_prediction_type(PredictionType.REGRESSION)
            .set_model_api_name(PREDICTION_MODEL_NAME)
            .set_prediction_columns(prediction_columns)
            .build()
        )

        prediction_response = runtime.einstein_predictions.predict(prediction_request)

        if not prediction_response.is_success:
            logger.error(f"Prediction failed: {prediction_response.data}")
            return None

        # Parse regression response
        if prediction_response.data is None:
            logger.warning("Prediction response data is None")
            return None

        results = prediction_response.data.get("results") or []
        if not results:
            logger.warning("No results in prediction response")
            return None

        first_result = results[0]
        prediction_type = first_result.get("type")

        if prediction_type != "RegressionPredictionSuccess":
            logger.error(f"Unexpected prediction type: {prediction_type}")
            logger.error(f"Full result: {first_result}")
            return None

        # "prediction" may be present but null; treat that like a missing key.
        prediction_data = first_result.get("prediction") or {}
        raw_value = prediction_data.get("value")

        if raw_value is None:
            logger.warning("No predicted value in response")
            return None

        # Convert before formatting: the ",.2f" format spec only accepts
        # numbers, so a value delivered as a numeric string would otherwise
        # raise and be reported as a failed prediction.
        predicted_value = float(raw_value)
        logger.info(f"Predicted sale price: ${predicted_value:,.2f}")

        # Log top contributors (which features influenced the price most)
        top_contributors = prediction_data.get("topContributors", [])
        if top_contributors:
            logger.info(f"Top price contributors: {top_contributors}")

        return predicted_value

    except Exception as e:
        # Best-effort boundary: any failure is logged with its traceback and
        # reported to the caller as "no prediction".
        logger.error(f"Prediction failed with error: {e}", exc_info=True)
        return None
139+
140+
141+
def enrich_property_with_price(
    source_dmo_fields: Dict[str, Any],
    runtime: Runtime,
) -> Dict[str, str]:
    """Build a citations mapping from source fields plus a price prediction.

    Args:
        source_dmo_fields: Property features from the source DMO; they are
            forwarded unchanged to ``predict_sale_price`` as model features.
        runtime: Runtime handed through to ``predict_sale_price``.

    Returns:
        A dict holding every source field rendered with ``str()``, followed by
        ``predicted_sale_price`` and ``prediction_status``; on success it also
        carries ``predicted_sale_price_raw``.
    """
    # Stringified copy of the source fields (empty when none were supplied).
    citations = {
        field_name: str(field_value)
        for field_name, field_value in (source_dmo_fields or {}).items()
    }

    # The raw (un-stringified) fields are what the model receives.
    price = predict_sale_price(source_dmo_fields, runtime)

    if price is None:
        citations.update(
            {
                "predicted_sale_price": "N/A",
                "prediction_status": "failed",
            }
        )
        return citations

    citations.update(
        {
            "predicted_sale_price": f"${price:,.2f}",
            "predicted_sale_price_raw": str(price),
            "prediction_status": "success",
        }
    )
    return citations
173+
174+
175+
def function(
    request: SearchIndexChunkingV1Request, runtime: Runtime
) -> SearchIndexChunkingV1Response:
    """Housing price prediction using Einstein regression.

    For every input document, forwards its ``source_dmo_fields`` (if any) to
    the price prediction and emits one TEXT chunk whose citations contain the
    stringified source fields plus the prediction outcome.

    Input format:
        {
            "input": [
                {
                    "text": "Beautiful 3BR house built in 1990",
                    "metadata": {
                        "source_dmo_fields": {
                            "Year_Built__c": 1990
                        }
                    }
                }
            ]
        }

    Output format:
        {
            "output": [
                {
                    "text": "Beautiful 3BR house built in 1990",
                    "seq_no": 1,
                    "citations": {
                        "Year_Built__c": "1990",
                        "predicted_sale_price": "$350,000.00",
                        "predicted_sale_price_raw": "350000.0",
                        "prediction_status": "success"
                    }
                }
            ]
        }

    Args:
        request: Input properties to enrich
        runtime: Runtime with prediction API access

    Returns:
        Properties enriched with predicted sale prices
    """
    enriched_properties = []

    # seq_no is 1-based and follows the order of the input documents.
    for seq_no, doc in enumerate(request.input, start=1):
        metadata = doc.metadata

        # Work on a plain-dict copy; documents without metadata or without
        # source fields are still emitted, just with no features to predict on.
        source_dmo_fields = {}
        if metadata and metadata.source_dmo_fields:
            source_dmo_fields = dict(metadata.source_dmo_fields)

        # Enrich with price prediction - pass source_dmo_fields directly
        citations = enrich_property_with_price(source_dmo_fields, runtime)

        enriched_properties.append(
            SearchIndexChunkingV1Output(
                chunk_type=ChunkType.TEXT,
                text=doc.text.strip(),
                seq_no=seq_no,
                citations=citations,
            )
        )

    return SearchIndexChunkingV1Response(output=enriched_properties)
Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,40 @@
1+
{
2+
"input": [
3+
{
4+
"text": "Luxury 5-bedroom house with 5000 sq ft living area, built in 2023",
5+
"metadata": {
6+
"type": "text",
7+
"source_dmo_fields": {
8+
"First_Flr_SF__c": 2600,
9+
"Full_Bath__c": 4,
10+
"Garage_Cars__c": 4,
11+
"Garage_Qual__c": "good",
12+
"Gr_Liv_Area__c": 5000,
13+
"Lot_Area__c": 3000,
14+
"Overall_Cond__c": 10,
15+
"Second_Flr_SF__c": 2400,
16+
"Total_Bsmt_SF__c": 0,
17+
"Year_Built__c": 2023
18+
}
19+
}
20+
},
21+
{
22+
"text": "Spacious 4-bedroom family home with 3500 sq ft living space, built in 2020",
23+
"metadata": {
24+
"type": "text",
25+
"source_dmo_fields": {
26+
"First_Flr_SF__c": 2000,
27+
"Full_Bath__c": 3,
28+
"Garage_Cars__c": 3,
29+
"Garage_Qual__c": "excellent",
30+
"Gr_Liv_Area__c": 3500,
31+
"Lot_Area__c": 8000,
32+
"Overall_Cond__c": 9,
33+
"Second_Flr_SF__c": 1500,
34+
"Total_Bsmt_SF__c": 1000,
35+
"Year_Built__c": 2020
36+
}
37+
}
38+
}
39+
]
40+
}

0 commit comments

Comments
 (0)