Skip to content

Commit

Permalink
add via stations
Browse files Browse the repository at this point in the history
  • Loading branch information
miaucl committed Apr 20, 2024
1 parent 16e31d8 commit b80f9c2
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 49 deletions.
14 changes: 7 additions & 7 deletions homeassistant/components/swiss_public_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import CONF_DESTINATION, CONF_START, DOMAIN
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .helper import unique_id_from_config

_LOGGER = logging.getLogger(__name__)

Expand All @@ -32,18 +33,19 @@ async def async_setup_entry(
start = config[CONF_START]
destination = config[CONF_DESTINATION]

unique_id = unique_id_from_config(config)
session = async_get_clientsession(hass)
opendata = OpendataTransport(start, destination, session)
opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA))

try:
await opendata.async_get_data()
except OpendataTransportConnectionError as e:
raise ConfigEntryNotReady(
f"Timeout while connecting for entry '{start} {destination}'"
f"Timeout while connecting for entry '{unique_id}'"
) from e
except OpendataTransportError as e:
raise ConfigEntryError(
f"Setup failed for entry '{start} {destination}' with invalid data, check "
f"Setup failed for entry '{unique_id}' with invalid data, check "
"at http://transport.opendata.ch/examples/stationboard.html if your "
"station names are valid"
) from e
Expand Down Expand Up @@ -78,9 +80,7 @@ async def async_migrate_entry(

if config_entry.minor_version == 1:
# Remove wrongly registered devices and entries
new_unique_id = (
f"{config_entry.data[CONF_START]} {config_entry.data[CONF_DESTINATION]}"
)
new_unique_id = unique_id_from_config(config_entry.data)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
device_entries = dr.async_entries_for_config_entry(
Expand Down
72 changes: 45 additions & 27 deletions homeassistant/components/swiss_public_transport/config_flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Config flow for swiss_public_transport."""

import logging
from types import MappingProxyType
from typing import Any

from opendata_transport import OpendataTransport
Expand All @@ -14,12 +15,24 @@
from homeassistant.const import CONF_NAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
TextSelector,
TextSelectorConfig,
TextSelectorType,
)

from .const import CONF_DESTINATION, CONF_START, DOMAIN, PLACEHOLDERS
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS
from .helper import unique_id_from_config

DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_START): cv.string,
vol.Optional(CONF_VIA): TextSelector(
TextSelectorConfig(
type=TextSelectorType.TEXT,
multiple=True,
),
),
vol.Required(CONF_DESTINATION): cv.string,
}
)
Expand All @@ -39,29 +52,34 @@ async def async_step_user(
"""Async user step to set up the connection."""
errors: dict[str, str] = {}
if user_input is not None:
await self.async_set_unique_id(
f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}"
)
unique_id = unique_id_from_config(MappingProxyType(user_input))
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START], user_input[CONF_DESTINATION], session
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA:
errors["base"] = "too_many_via_stations"
else:
return self.async_create_entry(
title=f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}",
data=user_input,
session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START],
user_input[CONF_DESTINATION],
session,
via=user_input.get(CONF_VIA),
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
else:
return self.async_create_entry(
title=unique_id,
data=user_input,
)

return self.async_show_form(
step_id="user",
Expand All @@ -72,14 +90,15 @@ async def async_step_user(

async def async_step_import(self, import_input: dict[str, Any]) -> ConfigFlowResult:
"""Async import step to set up the connection."""
await self.async_set_unique_id(
f"{import_input[CONF_START]} {import_input[CONF_DESTINATION]}"
)
unique_id = unique_id_from_config(MappingProxyType(import_input))
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
import_input[CONF_START], import_input[CONF_DESTINATION], session
import_input[CONF_START],
import_input[CONF_DESTINATION],
session,
)
try:
await opendata.async_get_data()
Expand All @@ -89,9 +108,8 @@ async def async_step_import(self, import_input: dict[str, Any]) -> ConfigFlowRes
return self.async_abort(reason="bad_config")
except Exception: # pylint: disable=broad-except
_LOGGER.error(
"Unknown error raised by python-opendata-transport for '%s %s', check at http://transport.opendata.ch/examples/stationboard.html if your station names and your parameters are valid",
import_input[CONF_START],
import_input[CONF_DESTINATION],
"Unknown error raised by python-opendata-transport for '%s', check at http://transport.opendata.ch/examples/stationboard.html if your station names and your parameters are valid",
unique_id,
)
return self.async_abort(reason="unknown")

Expand Down
8 changes: 6 additions & 2 deletions homeassistant/components/swiss_public_transport/const.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Constants for the swiss_public_transport integration."""

from typing import Final

DOMAIN = "swiss_public_transport"

CONF_DESTINATION = "to"
CONF_START = "from"
CONF_DESTINATION: Final = "to"
CONF_START: Final = "from"
CONF_VIA: Final = "via"

DEFAULT_NAME = "Next Destination"

MAX_VIA = 5
SENSOR_CONNECTIONS_COUNT = 3


Expand Down
15 changes: 15 additions & 0 deletions homeassistant/components/swiss_public_transport/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Helper functions for swiss_public_transport."""

from types import MappingProxyType
from typing import Any

from .const import CONF_DESTINATION, CONF_START, CONF_VIA


def unique_id_from_config(config: MappingProxyType[str, Any]) -> str:
"""Build a unique id from a config entry."""
return f"{config[CONF_START]} {config[CONF_DESTINATION]}" + (
" via " + ", ".join(config[CONF_VIA])
if CONF_VIA in config and len(config[CONF_VIA]) > 0
else ""
)
6 changes: 4 additions & 2 deletions homeassistant/components/swiss_public_transport/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"error": {
"cannot_connect": "Cannot connect to server",
"bad_config": "Request failed due to bad config: Check at [stationboard]({stationboard_url}) if your station names are valid",
"too_many_via_stations": "Too many via stations, only up to 5 via stations are allowed per connection.",
"unknown": "An unknown error was raised by python-opendata-transport"
},
"abort": {
Expand All @@ -15,9 +16,10 @@
"user": {
"data": {
"from": "Start station",
"to": "End station"
"to": "End station",
"via": "List of up to 5 via stations"
},
"description": "Provide start and end station for your connection\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"title": "Swiss Public Transport"
}
}
Expand Down
32 changes: 21 additions & 11 deletions tests/components/swiss_public_transport/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from homeassistant.components.swiss_public_transport.const import (
CONF_DESTINATION,
CONF_START,
CONF_VIA,
MAX_VIA,
)
from homeassistant.components.swiss_public_transport.helper import unique_id_from_config
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
Expand All @@ -25,6 +28,12 @@
MOCK_DATA_STEP = {
CONF_START: "test_start",
CONF_DESTINATION: "test_destination",
CONF_VIA: ["via_station"],
}

MOCK_DATA_STEP_TOO_MANY_STATIONS = {
**MOCK_DATA_STEP,
CONF_VIA: MOCK_DATA_STEP[CONF_VIA] * (MAX_VIA + 1),
}


Expand Down Expand Up @@ -52,22 +61,23 @@ async def test_flow_user_init_data_success(hass: HomeAssistant) -> None:
user_input=MOCK_DATA_STEP,
)

assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination"
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination via via_station"

assert result["data"] == MOCK_DATA_STEP


@pytest.mark.parametrize(
("raise_error", "text_error"),
("raise_error", "text_error", "user_input_error"),
[
(OpendataTransportConnectionError(), "cannot_connect"),
(OpendataTransportError(), "bad_config"),
(IndexError(), "unknown"),
(OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP),
(OpendataTransportError(), "bad_config", MOCK_DATA_STEP),
(None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS),
(IndexError(), "unknown", MOCK_DATA_STEP),
],
)
async def test_flow_user_init_data_error_and_recover(
hass: HomeAssistant, raise_error, text_error
hass: HomeAssistant, raise_error, text_error, user_input_error
) -> None:
"""Test unknown errors."""
with patch(
Expand All @@ -80,7 +90,7 @@ async def test_flow_user_init_data_error_and_recover(
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=MOCK_DATA_STEP,
user_input=user_input_error,
)

assert result["type"] is FlowResultType.FORM
Expand All @@ -94,8 +104,8 @@ async def test_flow_user_init_data_error_and_recover(
user_input=MOCK_DATA_STEP,
)

assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination"
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination via via_station"

assert result["data"] == MOCK_DATA_STEP

Expand All @@ -106,7 +116,7 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No
entry = MockConfigEntry(
domain=config_flow.DOMAIN,
data=MOCK_DATA_STEP,
unique_id=f"{MOCK_DATA_STEP[CONF_START]} {MOCK_DATA_STEP[CONF_DESTINATION]}",
unique_id=unique_id_from_config(MOCK_DATA_STEP),
)
entry.add_to_hass(hass)

Expand Down

0 comments on commit b80f9c2

Please sign in to comment.