Skip to content

Commit

Permalink
fixed server crash on unauthed websocket connections, added ip and po…
Browse files Browse the repository at this point in the history
…rt as arguments
  • Loading branch information
catink123 committed Jan 21, 2024
1 parent 331d250 commit c83cea5
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 12 deletions.
10 changes: 7 additions & 3 deletions src/auth.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace base64 = beast::detail::base64;
namespace fs = std::filesystem;

enum AuthorizationType {
View = 0,
Control = 1,
Blocked = 2
Blocked = 0,
View = 1,
Control = 2,
};

const std::unordered_map<std::string, std::optional<AuthorizationType>> endpoint_map = {
Expand Down Expand Up @@ -69,6 +69,10 @@ std::optional<AuthorizationType> get_auth(
const http::request<Body, http::basic_fields<Allocator>>& req,
const std::unordered_map<std::string, auth_data>& auth_table
) {
if (req.find(http::field::authorization) == req.end()) {
return std::nullopt;
}

// get the authorization field and separate the base64 encoded user-pass combination
const std::string authorization = req.at(http::field::authorization);
const std::string user_pass = authorization.substr(authorization.find(' ') + 1);
Expand Down
4 changes: 2 additions & 2 deletions src/http_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ http::message_generator handle_request(
};

res.set(http::field::server, VERSION);
res.set(http::field::www_authenticate, "Basic realm=\"control\"");
res.set(http::field::www_authenticate, "Basic realm=\"viewcontrol\"");
res.keep_alive(req.keep_alive());
res.body() = "Unauthorized client on resource '" + std::string(target) + "'.";
res.prepare_payload();
Expand Down Expand Up @@ -305,7 +305,7 @@ http::message_generator handle_request(
return unauthorized(req.target());
}

if (permissions != endpoint_perms) {
if (permissions < endpoint_perms) {
return forbidden(req.target());
}
}
Expand Down
61 changes: 54 additions & 7 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,58 @@
#include <string>
#include <vector>
#include <thread>
#include <limits>
#include "common.hpp"
#include "http_listener.hpp"
#include "common_state.hpp"

const auto ADDRESS = net::ip::make_address_v4("0.0.0.0");
const auto PORT = static_cast<unsigned short>(80);
const auto DEFAULT_ADDRESS = net::ip::make_address_v4("0.0.0.0");
const auto DEFAULT_PORT = static_cast<unsigned short>(80);
const auto DOC_ROOT = std::make_shared<std::string>("./client");
const auto THREAD_COUNT = 8;

int main(int argc, char* argv[]) {
if (argc < 3) {
printf("Usage: %s <com-port> <auth-file>", argv[0]);
printf("Usage: %s <com-port> <auth-file> [<ipv4-address>] [<port>]", argv[0]);
return 1;
}

std::optional<net::ip::address_v4> address;
std::optional<unsigned short> port;

if (argc == 4) {
try {
address = net::ip::make_address_v4(argv[3]);
}
catch (...) {
std::cerr << "Invalid IPv4 Address passed in an argument." << std::endl;
return 1;
}
}

if (!address) {
address = DEFAULT_ADDRESS;
}

if (argc == 5) {
try {
int port_int = std::stoi(argv[4]);
if (port_int > std::numeric_limits<unsigned short>::max() || port_int < 0) {
std::cerr << "Port passed as an argument is invalid." << std::endl;
return 1;
}
port = static_cast<unsigned short>(port_int);
}
catch (...) {
std::cerr << "Port passed as an argument is invalid." << std::endl;
return 1;
}
}

if (!port) {
port = DEFAULT_PORT;
}

fs::path auth_file_path;

try {
Expand Down Expand Up @@ -60,14 +97,14 @@ int main(int argc, char* argv[]) {

std::make_shared<http_listener>(
ioc,
tcp::endpoint{ADDRESS, PORT},
tcp::endpoint{address.value(), port.value()},
DOC_ROOT,
comstate,
arduino_connection,
auth_table_ptr
)->run();

std::cout << "Server started at " << ADDRESS << ":" << PORT << "." << std::endl;
std::cout << "Server started at " << DEFAULT_ADDRESS << ":" << DEFAULT_PORT << "." << std::endl;

// graceful shutdown
net::signal_set signals(ioc, SIGINT, SIGTERM);
Expand All @@ -84,11 +121,21 @@ int main(int argc, char* argv[]) {
for (auto i = 0; i < THREAD_COUNT - 1; ++i) {
v.emplace_back(
[&ioc] {
ioc.run();
try {
ioc.run();
}
catch (const std::exception& ex) {
std::cerr << "Stopping thread because of an unhandled exception: " << ex.what() << std::endl;
}
}
);
}
ioc.run();
try {
ioc.run();
}
catch (const std::exception& ex) {
std::cerr << "Stopping main thread because of an unhandled exception: " << ex.what() << std::endl;
}

// if the program is here, the graceful shutdown is in progress, wait for all threads to end
for (std::thread& th : v) {
Expand Down

0 comments on commit c83cea5

Please sign in to comment.