Skip to content

Commit

Permalink
allow uses to delete their memories
Browse files Browse the repository at this point in the history
  • Loading branch information
lbr88 committed Dec 26, 2024
1 parent 73fbc09 commit 62bcdc5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
34 changes: 33 additions & 1 deletion plugins/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,17 @@ def initialize(
needs_message_object=True,
returns_files=False,
)
delete_user_memory_or_memories = Tool(
function=self.delete_user_memory_or_memories,
description="Delete a memory or memories for the user. If you want to delete a specific memory or memories for the user, activate this tool to delete the memory or memories. This is useful for when you want to delete a specific memory or memories for a specific user.",
parameters=[
{"name": "memory_id", "required": False, "description": "The memory or memories to delete for the user. This can be a memory id"},
{"name": "mode", "required": True, "description": "The mode to delete the memory or memories. allowed values: ALL, ID or CONTEXT. If id is used, the memory_id parameter must be set, all will delete ALL memories created by the user in all contexts. context will delete all memories for the user in the current context"},
{"name": "tool_run", "required": True, "description": "this must always be true"}
],
needs_message_object=True,
returns_files=False,
)
enable_disable_memories = Tool(
function=self.enable_disable_memories,
description="GLOBALLY Enable/Disable Memories for the bot. This will enable or disable the bot from using memories for the bot. This is useful for when you want to disable memories for a specific context.",
Expand Down Expand Up @@ -361,6 +372,7 @@ def initialize(
self.tools_manager.add_tool(search_user_memories)
#self.tools_manager.add_tool(enable_disable_memories)
self.tools_manager.add_tool(enable_disable_memories_user)
self.tools_manager.add_tool(delete_user_memory_or_memories)
self.user_tools = self.tools_manager.get_tools_as_dict("user")
self.admin_tools = self.tools_manager.get_tools_as_dict("admin")
# print the tools
Expand All @@ -370,7 +382,27 @@ def initialize(
self.helper.slog(
"Admin tools: " + ", ".join(self.tools_manager.get_tools("admin").keys())
)

async def delete_user_memory_or_memories(self, message: Message, mode: str, memory_id: str = None, tool_run=False):
"""Delete a memory or memories for the user"""
if message.is_direct_message:
usage_context = self.usage_context.DIRECT
source_type = "direct"
source = message.user_id
else:
usage_context = self.usage_context.CHANNEL
source_type = "channel"
source = message.channel_id
if mode.lower() not in ["all", "id", "context"]:
return "Error: mode must be all, id, or context"
if mode.lower() == "id":
if memory_id is None:
return "Error: memory_id must be set when mode is id"
self.vectordb.user_delete_memory_by_id(memory_id, message.user_id, usage_context, source_type, source)
elif mode.lower() == "all":
self.vectordb.user_delete_all_memories(message.user_id)
else:
self.vectordb.user_delete_all_memories_in_context(message.user_id, usage_context, source_type, source)
return f"Memory or memories deleted for {mode} mode"
@listen_to("^.gpt memories get")
async def get_user_memories(self, message: Message, tool_run=False):
"""Get all memories for the user"""
Expand Down
20 changes: 16 additions & 4 deletions plugins/vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ def store_multiple(self, table: str, contexts: list, tags: list, content: list):
"""
def user_has_memories(self, user: str, usage_context: UsageContext, source_type: str, source: str):
"""Check if a user has memories."""
result = self.conn.execute(SQL("SELECT id FROM {} WHERE created_by = %s AND usage_context = %s AND source_type = %s AND source = %s").format(Identifier(self.DEFAULT_TABLE)), (user, usage_context.value.lower(), source_type, source))
result = self.conn.execute(SQL("SELECT id FROM {} WHERE created_by = %s AND usage_context = %s AND source_type = %s AND source = %s AND is_deleted is FALSE").format(Identifier(self.DEFAULT_TABLE)), (user, usage_context.value.lower(), source_type, source))
if len(result.fetchall()) > 0:
return True
return False
def check_if_memory_exists(self, usage_context: UsageContext, content, user: str, source_type: str, source: str):
"""Check if a memory exists."""
result = self.conn.execute(SQL("SELECT id FROM {} WHERE content = %s AND created_by = %s AND usage_context = %s AND source_type = %s and source = %s").format(Identifier(self.DEFAULT_TABLE)), (content, user, usage_context.value.lower(), source_type, source))
result = self.conn.execute(SQL("SELECT id FROM {} WHERE content = %s AND created_by = %s AND usage_context = %s AND source_type = %s and source = %s AND is_deleted = FALSE").format(Identifier(self.DEFAULT_TABLE)), (content, user, usage_context.value.lower(), source_type, source))
if len(result.fetchall()) > 0:
return True
return False
Expand All @@ -174,8 +174,19 @@ def get_memories(self, query: str, usage_context: UsageContext, user: str, sourc
return self.search(self.DEFAULT_TABLE, query=query, usage_context=usage_context, category="memory", user=user, source_type=source_type, source=source, limit=limit)
def get_all_memories_for_user_for_context(self, user: str, usage_context: UsageContext, source_type: str, source: str):
"""Get all memories for a user."""
return self.conn.execute(SQL("SELECT id, created_at, tags, content FROM {} WHERE created_by = %s AND usage_context = %s AND source_type = %s AND source = %s").format(Identifier(self.DEFAULT_TABLE)), (user,usage_context.value.lower(), source_type, source)).fetchall()

return self.conn.execute(SQL("SELECT id, created_at, tags, content FROM {} WHERE created_by = %s AND usage_context = %s AND source_type = %s AND source = %s AND is_deleted = FALSE").format(Identifier(self.DEFAULT_TABLE)), (user,usage_context.value.lower(), source_type, source)).fetchall()
def user_delete_memory_by_id(self, memory_id: int, user: str, usage_context: UsageContext, source_type: str, source: str):
"""Delete a memory."""
result = self.conn.execute(SQL("UPDATE {} SET is_deleted = TRUE WHERE id = %s AND created_by = %s AND usage_context = %s AND source_type = %s AND source = %s").format(Identifier(self.DEFAULT_TABLE)), (memory_id, user, usage_context.value.lower(), source_type, source))
return result
def user_delete_all_memories_in_context(self, user: str, usage_context: UsageContext, source_type: str, source: str):
"""Delete all memories for a user in a context."""
result = self.conn.execute(SQL("UPDATE {} SET is_deleted = TRUE WHERE created_by = %s AND usage_context = %s AND source_type = %s AND source = %s").format(Identifier(self.DEFAULT_TABLE)), (user, usage_context.value.lower(), source_type, source))
return result
def user_delete_all_memories(self, user: str):
"""Delete all memories for a user."""
result = self.conn.execute(SQL("UPDATE {} SET is_deleted = TRUE WHERE created_by = %s").format(Identifier(self.DEFAULT_TABLE)), (user,))
return result
def store(self, table: str, source_type: str, source: str, usage_context: UsageContext, category: str, tags: list, content: str, metadata: dict, created_by: str):
"""Store a memory."""
# create table if not exists
Expand All @@ -201,6 +212,7 @@ def search(self, table: str | None, query: str, user: str, usage_context: UsageC
result = self.conn.execute(
SQL("""SELECT id, tags, metadata, category, source, source_type, created_at, content, embedding <-> %s::vector as distance FROM {}
WHERE
is_deleted = FALSE AND
(embedding <-> %s::vector) < %s AND
(usage_context = %s OR usage_context = 'any')
AND (created_by = %s AND source_type = %s AND source = %s)
Expand Down

0 comments on commit 62bcdc5

Please sign in to comment.