Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for listening on a UNIX socket instead of IP #116259

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 63 additions & 5 deletions homeassistant/components/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ipaddress import IPv4Network, IPv6Network, ip_network
import logging
import os
import shutil
import socket
import ssl
from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -85,6 +86,9 @@

CONF_SERVER_HOST: Final = "server_host"
CONF_SERVER_PORT: Final = "server_port"
CONF_SOCKET_USER: Final = "socket_user"
CONF_SOCKET_GROUP: Final = "socket_group"
CONF_SOCKET_PERMISSIONS: Final = "socket_permissions"
CONF_BASE_URL: Final = "base_url"
CONF_SSL_CERTIFICATE: Final = "ssl_certificate"
CONF_SSL_PEER_CERTIFICATE: Final = "ssl_peer_certificate"
Expand Down Expand Up @@ -127,6 +131,11 @@
cv.ensure_list, vol.Length(min=1), [cv.string]
),
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT): cv.port,
vol.Optional(CONF_SOCKET_USER): vol.Any(vol.Coerce(int), vol.Coerce(str)),
vol.Optional(CONF_SOCKET_GROUP): vol.Any(vol.Coerce(int), vol.Coerce(str)),
vol.Optional(CONF_SOCKET_PERMISSIONS): vol.All(
vol.Coerce(int), vol.Range(min=0, max=0o777)
),
vol.Optional(CONF_BASE_URL): cv.string,
vol.Optional(CONF_SSL_CERTIFICATE): cv.isfile,
vol.Optional(CONF_SSL_PEER_CERTIFICATE): cv.isfile,
Expand Down Expand Up @@ -161,6 +170,9 @@ class ConfData(TypedDict, total=False):

server_host: list[str]
server_port: int
socket_user: int | str
socket_group: int | str
socket_permissions: int
base_url: str
ssl_certificate: str
ssl_peer_certificate: str
Expand Down Expand Up @@ -210,6 +222,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

server_host = conf[CONF_SERVER_HOST]
server_port = conf[CONF_SERVER_PORT]
socket_user = conf.get(CONF_SOCKET_USER)
socket_group = conf.get(CONF_SOCKET_GROUP)
socket_permissions = conf.get(CONF_SOCKET_PERMISSIONS)
ssl_certificate = conf.get(CONF_SSL_CERTIFICATE)
ssl_peer_certificate = conf.get(CONF_SSL_PEER_CERTIFICATE)
ssl_key = conf.get(CONF_SSL_KEY)
Expand All @@ -232,6 +247,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
ssl_key=ssl_key,
trusted_proxies=trusted_proxies,
ssl_profile=ssl_profile,
socket_user=socket_user,
socket_group=socket_group,
socket_permissions=socket_permissions,
)
await server.async_initialize(
cors_origins=cors_origins,
Expand Down Expand Up @@ -320,6 +338,9 @@ def __init__(
server_port: int,
trusted_proxies: list[IPv4Network | IPv6Network],
ssl_profile: str,
socket_user: int | str | None,
socket_group: int | str | None,
socket_permissions: int | None,
) -> None:
"""Initialize the HTTP Home Assistant server."""
self.app = HomeAssistantApplication(
Expand All @@ -342,8 +363,11 @@ def __init__(
self.server_port = server_port
self.trusted_proxies = trusted_proxies
self.ssl_profile = ssl_profile
self.socket_user = socket_user
self.socket_group = socket_group
self.socket_permissions = socket_permissions
self.runner: web.AppRunner | None = None
self.site: HomeAssistantTCPSite | None = None
self.site: web.BaseSite | None = None
self.context: ssl.SSLContext | None = None

async def async_initialize(
Expand Down Expand Up @@ -563,17 +587,51 @@ async def start(self) -> None:
)
await self.runner.setup()

self.site = HomeAssistantTCPSite(
self.runner, self.server_host, self.server_port, ssl_context=self.context
)
socket_path: str | None = None
if self.server_host and self.server_host[0].startswith("unix:"):
socket_path = self.server_host[0].removeprefix("unix:")
self.site = web.UnixSite(
self.runner,
socket_path,
ssl_context=self.context,
)
else:
self.site = HomeAssistantTCPSite(
self.runner,
self.server_host,
self.server_port,
ssl_context=self.context,
)
try:
await self.site.start()
except OSError as error:
_LOGGER.error(
"Failed to create HTTP server at port %d: %s", self.server_port, error
)

_LOGGER.info("Now listening on port %d", self.server_port)
if socket_path is not None:
# They didn't find a way to put this in aiohttp yet so we have to do it here
# https://github.com/aio-libs/aiohttp/issues/4155#issuecomment-643509809
if self.socket_permissions is not None:
try:
os.chmod(socket_path, self.socket_permissions)
except OSError as error:
_LOGGER.error(
"Failed to change permissions on %s: %s", socket_path, error
)
if self.socket_user is not None or self.socket_group is not None:
try:
shutil.chown(
socket_path,
self.socket_user or -1,
self.socket_group or -1,
)
except OSError as error:
_LOGGER.error(
"Failed to change user/group on %s: %s", socket_path, error
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does blocking I/O in the event loop

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes that's a pretty silly thing to just forget.


_LOGGER.info("Now listening on %s", self.site.name)

async def stop(self) -> None:
"""Stop the aiohttp server."""
Expand Down
7 changes: 6 additions & 1 deletion homeassistant/components/http/forwarded.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Awaitable, Callable
from ipaddress import IPv4Network, IPv6Network, ip_address
import logging
import socket

from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
Expand Down Expand Up @@ -90,7 +91,11 @@ async def forwarded_middleware(
# Connected IP isn't retrieveable from the request transport, continue
return await handler(request)

connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
if request.transport.get_extra_info("socket").family == socket.AF_UNIX:
# UNIX sockets won't have a peername but always come from localhost anyway
connected_ip = ip_address("127.0.0.1")
else:
connected_ip = ip_address(request.transport.get_extra_info("peername")[0])

# We have X-Forwarded-For, but config does not agree
if not use_x_forwarded_for:
Expand Down