Skip to content

Commit

Permalink
refactor: re-impl the server
Browse files Browse the repository at this point in the history
  • Loading branch information
TrickyPi committed May 17, 2024
1 parent 343c681 commit d4768fe
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 88 deletions.
100 changes: 97 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ edition = "2021"
clap = { version = "4.5", features = ["derive"] }
hyper = { version = "1.3.1", features = ["full"] }
tokio = { version = "1.28.2", features = ["full"] }
http-body-util = "0.1.1"
hyper-util = { version = "0.1.3", features = ["full"] }
local-ip-address = "0.6.1"
colored = "2.1.0"
9 changes: 5 additions & 4 deletions src/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@ impl Addr {
network_ip: local_ip().unwrap(),
}
}
pub fn is_free_port(&self, port: Port) -> Option<(SocketAddr, SocketAddr)> {
pub fn is_free_port(&self, port: Port) -> Option<(SocketAddr, SocketAddr, String)> {
let Addr {
local_ip,
network_ip,
} = self;
let local_addr = SocketAddr::new(IpAddr::V4(*local_ip), port);
let network_addr = SocketAddr::new(*network_ip, port);
if TcpListener::bind(local_addr).is_ok() | TcpListener::bind(network_addr).is_ok() {
return Some((local_addr, network_addr));
let bind_addr = format!("[::]:{}", port);
if TcpListener::bind(&bind_addr).is_ok() {
return Some((local_addr, network_addr, bind_addr));
}
None
}
pub fn get_address(&self, port: Port) -> (SocketAddr, SocketAddr) {
pub fn get_address(&self, port: Port) -> (SocketAddr, SocketAddr, String) {
if let Some(addr) = self.is_free_port(port) {
return addr;
}
Expand Down
120 changes: 61 additions & 59 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use clap::Parser;
use hyper::header;
use hyper::header::HeaderValue;
use hyper::service::{make_service_fn, service_fn};
use hyper::{header, Client, Server};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use std::convert::Infallible;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;

use chost::addr::{Addr, Port};
use chost::response::{file::response_file_content, proxy::proxy_response};
Expand All @@ -15,26 +17,27 @@ use chost::utils::get_full_addr_string;
#[derive(Parser)]
struct Cli {
/// path to host
#[clap(parse(from_os_str))]
path: Option<PathBuf>,
/// enable cors
#[clap(short, long)]
#[arg(short, long)]
cors: bool,
/// port
#[clap(short, long, default_value_t = 7878)]
#[arg(short, long, default_value_t = 7878)]
port: Port,
/// forwarding request to other service, the format is "${api}|${origin} ${api}|${origin}"
#[clap(long, value_delimiter = ' ')]
#[arg(long, value_delimiter = ' ')]
proxy: Option<Vec<String>>,
}

#[tokio::main]
async fn main() {
let args = Cli::parse();
create_server(args).await;
if let Err(e) = create_server(args).await {
eprintln!("server error: {}", e);
};
}

async fn create_server(args: Cli) {
async fn create_server(args: Cli) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let Cli {
port,
path,
Expand Down Expand Up @@ -65,63 +68,62 @@ async fn create_server(args: Cli) {
});

let proxies_arc = Arc::new(proxies);
let cors_arc = Arc::new(cors);

let client = Client::builder()
.pool_idle_timeout(Duration::from_secs(1000))
.build_http::<hyper::Body>();
let addr = Addr::new();
let (local_addr, network_addr, bind_addr) = addr.get_address(port);

let listener = TcpListener::bind(bind_addr).await?;

println!("local server on {}", get_full_addr_string(&local_addr));
println!("network server on {}", get_full_addr_string(&network_addr));

let make_svc = make_service_fn(|_| {
loop {
let path = path.clone();
let cors_arc = cors_arc.clone();
let proxies_arc = proxies_arc.clone();
let client = client.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
let req_path = req.uri().path().strip_prefix('/').unwrap().to_owned();
let method = req.method().clone();

let path = path.clone();
let cors_arc = cors_arc.clone();
let proxies_arc = proxies_arc.clone();
let client = client.clone();

async move {
let mut resp =
if let Some(resp) = proxy_response(client, req, &proxies_arc).await {
resp
} else {
response_file_content(path, method, req_path).await
};
if *cors_arc {
let headers = resp.headers_mut();
let not_limited_value = HeaderValue::from_static("*");
headers.insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
not_limited_value.clone(),
);
headers.insert(
header::ACCESS_CONTROL_ALLOW_METHODS,
not_limited_value.clone(),
);
headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, not_limited_value);
}
Ok::<_, Infallible>(resp)
}
}))
}
});

let addr = Addr::new();
let (local_addr, network_addr) = addr.get_address(port);
let (stream, _addr) = listener.accept().await?;

let local_server = Server::bind(&local_addr).serve(make_svc);
let network_server = Server::bind(&network_addr).serve(make_svc);
let io = TokioIo::new(stream);

println!("local server on {}", get_full_addr_string(&local_addr));
println!("network server on {}", get_full_addr_string(&network_addr));
tokio::task::spawn(async move {
http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| {
let req_path = req.uri().path().strip_prefix('/').unwrap().to_owned();
let method = req.method().clone();

if let Err(e) = tokio::try_join!(local_server, network_server) {
eprintln!("server error: {}", e);
let path = path.clone();
let proxies_arc = proxies_arc.clone();

async move {
let mut resp =
if let Some(resp) = proxy_response(req, &proxies_arc).await {
resp
} else {
response_file_content(path, method, req_path).await
};
if cors {
let headers = resp.headers_mut();
let not_limited_value = HeaderValue::from_static("*");
headers.insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
not_limited_value.clone(),
);
headers.insert(
header::ACCESS_CONTROL_ALLOW_METHODS,
not_limited_value.clone(),
);
headers.insert(
header::ACCESS_CONTROL_ALLOW_HEADERS,
not_limited_value,
);
}
Ok::<_, Infallible>(resp)
}
}),
)
.await
});
}
}

0 comments on commit d4768fe

Please sign in to comment.