diff --git a/samples/apps/autogen-studio/autogenstudio/datamodel.py b/samples/apps/autogen-studio/autogenstudio/datamodel.py index ee48818d599f..04239e49b7e9 100644 --- a/samples/apps/autogen-studio/autogenstudio/datamodel.py +++ b/samples/apps/autogen-studio/autogenstudio/datamodel.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Union +from pydantic import BaseModel from sqlalchemy import ForeignKey, Integer, orm from sqlmodel import ( JSON, @@ -295,3 +296,32 @@ class SocketMessage(SQLModel, table=False): connection_id: str data: Dict[str, Any] type: str + + +class ExportedAgent(BaseModel): + id: Optional[int] = None + user_id: Optional[str] = None + version: Optional[str] = "0.0.1" + type: AgentType = AgentType.assistant + config: Union[AgentConfig, dict] + skills: List[Skill] + models: List[Model] + agents: List["ExportedAgent"] + task_instruction: Optional[str] = None + + +class ExportedAgentWithLink(BaseModel): + agent: ExportedAgent + link: WorkflowAgentLink + + +class ExportedWorkflow(BaseModel): + id: Optional[int] = None + user_id: str = None + name: str + summary_method: str = WorkFlowSummaryMethod.last + sample_tasks: Optional[List[str]] + version: str = "0.0.1" + description: str + type: str = WorkFlowType.autonomous + agents: List[ExportedAgentWithLink] diff --git a/samples/apps/autogen-studio/autogenstudio/web/app.py b/samples/apps/autogen-studio/autogenstudio/web/app.py index bbd087f52ea2..7f9025bdf664 100644 --- a/samples/apps/autogen-studio/autogenstudio/web/app.py +++ b/samples/apps/autogen-studio/autogenstudio/web/app.py @@ -4,7 +4,7 @@ import threading import traceback from contextlib import asynccontextmanager -from typing import Any, Union +from typing import Any, List, Optional, Union from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -15,7 +15,7 @@ from ..chatmanager import AutoGenChatManager from ..database import workflow_from_id from ..database.dbmanager import DBManager -from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow +from ..datamodel import Agent, ExportedAgent, ExportedWorkflow, Message, Model, Response, Session, Skill, Workflow from ..profiler import Profiler from ..utils import check_and_cast_datetime_fields, init_app_folders, md5_hash, test_model from ..version import VERSION @@ -322,6 +322,43 @@ async def export_workflow(workflow_id: int, user_id: str): return response.model_dump(mode="json") +@api.post("/workflows/import") +async def import_workflow(exported_workflow: ExportedWorkflow): + """Import a user workflow""" + + async def create_agent_with_links(agent_data: ExportedAgent, parent_agent_id: Optional[int] = None): + agent = Agent(**agent_data.model_dump(exclude={"id", "models", "skills", "agents"})) + await create_agent(agent) + + for model in agent_data.models: + model.id = None + await create_model(model) + await link_agent_model(agent.id, model.id) + + for skill in agent_data.skills: + skill.id = None + await create_skill(skill) + await link_agent_skill(agent.id, skill.id) + + if parent_agent_id: + await link_agent_agent(parent_agent_id, agent.id) + + for nested_agent_data in agent_data.agents: + await create_agent_with_links(nested_agent_data, agent.id) + + return agent + + workflow = Workflow(**exported_workflow.model_dump(exclude={"id", "agents"})) + await create_workflow(workflow) + + for agent_with_link in exported_workflow.agents: + created_agent = await create_agent_with_links(agent_with_link.agent, None) + await link_workflow_agent(workflow.id, created_agent.id, agent_with_link.link.agent_type) + + response: Response = Response(status=True, message="Imported workflow", data=workflow) + return response + + @api.post("/workflows") async def create_workflow(workflow: Workflow): """Create a new workflow"""