Skip to content

Commit

Permalink
feat: protect api with AUTHORIZATION (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
xsigoking committed Apr 11, 2024
1 parent 6b5097b commit f1ae2b7
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,47 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
async fn main() -> Result<()> {
init_logger();

let mut has_envs = [false; 3];

let port = if let Ok(port) = env::var("PORT") {
has_envs[0] = true;
port.parse::<u16>()
.map_err(|_| anyhow!("Invalid environment variable $PORT"))?
} else {
PORT
};
let mut client_builder = ClientBuilder::new().connect_timeout(CONNECT_TIMEOUT);
if let Ok(proxy) = env::var("ALL_PROXY") {
has_envs[1] = true;
client_builder = client_builder.proxy(
Proxy::all(proxy)
.map_err(|err| anyhow!("Invalid environment variable $ALL_PROXY, {err}"))?,
);
};
let listener = tokio::net::TcpListener::bind(&format!("0.0.0.0:{port}")).await?;

let authorization = env::var("AUTHORIZATION").ok().and_then(|v| {
if v.is_empty() {
None
} else {
has_envs[2] = true;
Some(v)
}
});
let server = Arc::new(Server {
client: client_builder.build()?,
authorization,
});
let [port_has_env, all_proxy_has_env, authorization_has_env] =
has_envs.map(|v| if v { " ✅" } else { "" });
let stop_server = server.run(listener).await?;
println!(
r#"Access the API server at: http://0.0.0.0:{port}/v1/chat/completions
Environment Variables:
- PORT: change the listening port, defaulting to {PORT}
- ALL_PROXY: set the proxy server
- PORT: change the listening port, defaulting to {PORT}{port_has_env}
- ALL_PROXY: configure the proxy server, supporting HTTP, HTTPS, and SOCKS5 protocols{all_proxy_has_env}
- AUTHORIZATION: only for internal use to protect the API and will not be sent to OpenAI{authorization_has_env}
Please contact us at https://github.com/xsigoking/chatgpt-free-api if you encounter any issues.
"#
Expand All @@ -79,6 +95,7 @@ type AppResponse = Response<BoxBody<Bytes, Infallible>>;

struct Server {
client: Client,
authorization: Option<String>,
}

impl Server {
Expand Down Expand Up @@ -121,24 +138,41 @@ impl Server {
) -> std::result::Result<AppResponse, hyper::Error> {
let method = req.method().clone();
let uri = req.uri().clone();
let res = if method == Method::POST && uri == "/v1/chat/completions" {
let mut auth_failed = false;
if let Some(expect_authorization) = &self.authorization {
if let Some(authorization) = req.headers().get("authorization") {
if authorization.as_bytes() != expect_authorization.as_bytes() {
auth_failed = true;
}
} else {
auth_failed = true;
}
}
let mut status = StatusCode::OK;
let res = if auth_failed {
status = StatusCode::UNAUTHORIZED;
Err(anyhow!(
"No authorization header or invalid authorization value."
))
} else if method == Method::POST && uri == "/v1/chat/completions" {
self.chat_completion(req).await
} else if method == Method::GET && uri == "/v1/models" {
self.models(req).await
} else {
status = StatusCode::NOT_FOUND;
Err(anyhow!("The requested endpoint was not found."))
};
let mut res = match res {
Ok(res) => {
info!("{} {}", method, uri);
info!("{method} {uri} {}", status.as_u16());
res
}
Err(err) => {
info!("{} {}", method, uri);
error!("api error: {err}");
error!("{method} {uri} {} {err}", status.as_u16());
create_error_response(err)
}
};
*res.status_mut() = status;
set_cors_header(&mut res);
Ok(res)
}
Expand Down

0 comments on commit f1ae2b7

Please sign in to comment.