-
Notifications
You must be signed in to change notification settings - Fork 0
/
summarize.py
107 lines (94 loc) · 3.37 KB
/
summarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import boto3
from griptape.drivers import (
AmazonRedshiftSqlDriver,
AmazonBedrockPromptDriver,
BedrockClaudePromptModelDriver,
BedrockTitanPromptModelDriver,
BedrockTitanEmbeddingDriver,
)
from griptape.loaders import SqlLoader
from griptape.memory import TaskMemory
from griptape.rules import Ruleset, Rule
from griptape.structures import Agent
from griptape.tools import SqlClient, FileManager, TaskMemoryClient
from griptape.artifacts import TextArtifact, BlobArtifact
from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage
from griptape.drivers import LocalVectorStoreDriver
from griptape.engines import (
VectorQueryEngine,
PromptSummaryEngine,
CsvExtractionEngine,
JsonExtractionEngine,
)
from dotenv import load_dotenv
# Load settings from a .env file into the process environment. load_dotenv()
# is called without arguments, so python-dotenv's default .env lookup applies.
load_dotenv()

# AWS region for every AWS-backed driver built later in this script. It is a
# required setting: os.environ[...] raises KeyError if AWS_REGION is unset.
aws_region = os.environ["AWS_REGION"]

# A single boto3 session, shared by the Redshift and Bedrock drivers below.
session = boto3.Session(region_name=aws_region)
# Redshift connection settings. REDSHIFT_DATABASE and the Secrets Manager ARN
# are required (os.environ raises KeyError if they are missing).
# REDSHIFT_WORKGROUP_NAME is read with os.getenv, so it is None when unset.
# NOTE(review): the workgroup presumably targets Redshift Serverless. Confirm
# how AmazonRedshiftSqlDriver behaves when workgroup_name is None.
redshift_driver = AmazonRedshiftSqlDriver(
    database=os.environ["REDSHIFT_DATABASE"],
    session=session,
    database_credentials_secret_arn=os.environ["REDSHIFT_CREDENTIALS_SECRETS_MANAGER_ARN"],
    workgroup_name=os.getenv("REDSHIFT_WORKGROUP_NAME"),
)

# Loader that executes SQL through the Redshift driver above.
sql_loader = SqlLoader(sql_driver=redshift_driver)

# Tool that gives the agent SQL access to the `people` table. The description
# is presumably shown to the LLM so it knows what the table contains.
sql_tool = SqlClient(
    sql_loader=sql_loader,
    table_name="people",
    table_description="contains information about tech industry professionals",
    engine_name="redshift",
)
# File tool. The final prompt asks the agent to save occupations.txt.
file_manager = FileManager()

# Client the agent uses to work with task memory.
# NOTE(review): off_prompt=True presumably keeps this tool's results in task
# memory instead of returning them to the LLM. Confirm against Griptape docs.
task_memory_client = TaskMemoryClient(off_prompt=True)

# Persona and sourcing rules for the agent.
support_agent_rules = [
    Rule("Act and introduce yourself as a HumansOrg, Inc. support agent"),
    Rule("Your main objective is to help with finding information about people"),
    Rule("Only use information about people from the sources available to you"),
]
ruleset = Ruleset(name="HumansOrg Agent", rules=support_agent_rules)
# Bedrock model identifiers used by this script.
AGENT_MODEL_ID = "anthropic.claude-v2"
TASK_MEMORY_MODEL_ID = "amazon.titan-text-express-v1"

# Claude v2: the prompt driver handed to the Agent itself.
prompt_driver = AmazonBedrockPromptDriver(
    model=AGENT_MODEL_ID,
    prompt_model_driver=BedrockClaudePromptModelDriver(),
    session=session,
)

# Titan Text Express: shared by the task-memory engines (vector query,
# summary, CSV and JSON extraction) configured when the agent is built.
task_memory_prompt_driver = AmazonBedrockPromptDriver(
    model=TASK_MEMORY_MODEL_ID,
    prompt_model_driver=BedrockTitanPromptModelDriver(),
    session=session,
)

# Titan embeddings: used by the task-memory vector store and also passed to
# the Agent as its embedding driver.
task_memory_embedding_driver = BedrockTitanEmbeddingDriver(session=session)
# Storage for text artifacts in task memory. All of its engines run on the
# Titan task-memory prompt driver. Vector queries go through a
# LocalVectorStoreDriver that embeds with the Titan embedding driver.
text_artifact_storage = TextArtifactStorage(
    query_engine=VectorQueryEngine(
        prompt_driver=task_memory_prompt_driver,
        vector_store_driver=LocalVectorStoreDriver(
            embedding_driver=task_memory_embedding_driver
        ),
    ),
    summary_engine=PromptSummaryEngine(prompt_driver=task_memory_prompt_driver),
    csv_extraction_engine=CsvExtractionEngine(prompt_driver=task_memory_prompt_driver),
    json_extraction_engine=JsonExtractionEngine(prompt_driver=task_memory_prompt_driver),
)

# Task memory keyed by artifact type. Text goes to the engine-backed storage
# above; blobs go to a plain BlobArtifactStorage.
agent_task_memory = TaskMemory(
    artifact_storages={
        TextArtifact: text_artifact_storage,
        BlobArtifact: BlobArtifactStorage(),
    }
)

# The agent is assembled from:
# - the SQL, file and task-memory tools
# - the HumansOrg ruleset
# - the Claude prompt driver
# - the Titan-backed task memory
agent = Agent(
    tools=[sql_tool, file_manager, task_memory_client],
    rulesets=[ruleset],
    prompt_driver=prompt_driver,
    embedding_driver=task_memory_embedding_driver,
    task_memory=agent_task_memory,
)
# Start the agent run only when this file is executed as a script. The run
# makes LLM calls and writes a file, so importing this module (e.g. to reuse
# the configured agent or drivers) must not trigger it.
if __name__ == "__main__":
    # Ask the agent to build a names/occupations summary and save it as
    # occupations.txt in the current working directory.
    agent.run(
        "Summarize a report of tech industry professional's names and occupations "
        "and save to the current directory in a file called occupations.txt"
    )