-
Notifications
You must be signed in to change notification settings - Fork 79
/
server.py
115 lines (93 loc) · 3.34 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import asyncio
import argparse
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import concurrent.futures
from marker.logger import configure_logging # Import logging configuration
from marker.models import load_all_models # Import function to load models
from marker_api.routes import (
process_pdf_file,
)
from marker_api.utils import print_markerapi_text_art
from contextlib import asynccontextmanager
import logging
import gradio as gr
from marker_api.model.schema import (
BatchConversionResponse,
ConversionResponse,
HealthResponse,
ServerType,
)
from marker_api.demo import demo_ui
# Initialize logging
configure_logging()
logger = logging.getLogger(__name__)
# Global variable to hold model list
model_list = None
# Event that runs on startup to load all models
@asynccontextmanager
async def lifespan(app: FastAPI):
global model_list
logger.debug("--------------------- Loading OCR Model -----------------------")
print_markerapi_text_art()
model_list = load_all_models()
yield
# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)
# Add CORS middleware to allow cross-origin requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
)
app = gr.mount_gradio_app(app, demo_ui, path="")
@app.get("/health", response_model=HealthResponse)
def server():
"""
Root endpoint to check server status.
"""
return HealthResponse(message="Welcome to Marker-api", type=ServerType.simple)
# Endpoint to convert a single PDF to markdown
@app.post("/convert", response_model=ConversionResponse)
async def convert_pdf_to_markdown(pdf_file: UploadFile):
"""
Endpoint to convert a single PDF to markdown.
"""
logger.debug(f"Received file: {pdf_file.filename}")
file = await pdf_file.read()
response = process_pdf_file(file, pdf_file.filename, model_list)
return ConversionResponse(status="Success", result=response)
# Endpoint to convert multiple PDFs to markdown
@app.post("/batch_convert", response_model=BatchConversionResponse)
async def convert_pdfs_to_markdown(pdf_files: List[UploadFile] = File(...)):
"""
Endpoint to convert multiple PDFs to markdown.
"""
logger.debug(f"Received {len(pdf_files)} files for batch conversion")
async def process_files(files):
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
coroutines = [
loop.run_in_executor(
pool, process_pdf_file, await file.read(), file.filename, model_list
)
for file in files
]
return await asyncio.gather(*coroutines)
responses = await process_files(pdf_files)
return BatchConversionResponse(results=responses)
# Main function to run the server
def main():
parser = argparse.ArgumentParser(description="Run the marker-api server.")
parser.add_argument("--host", default="0.0.0.0", help="Host IP address")
parser.add_argument("--port", type=int, default=8080, help="Port number")
args = parser.parse_args()
import uvicorn
uvicorn.run("server:app", host=args.host, port=args.port)
# Entry point to start the server
if __name__ == "__main__":
main()