Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Swiss public transport via stations #115891

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 12 additions & 8 deletions homeassistant/components/swiss_public_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import CONF_DESTINATION, CONF_START, DOMAIN
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .helper import unique_id_from_config

_LOGGER = logging.getLogger(__name__)

Expand All @@ -32,18 +33,19 @@ async def async_setup_entry(
start = config[CONF_START]
destination = config[CONF_DESTINATION]

unique_id = unique_id_from_config(config)
session = async_get_clientsession(hass)
opendata = OpendataTransport(start, destination, session)
opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA))

try:
await opendata.async_get_data()
except OpendataTransportConnectionError as e:
raise ConfigEntryNotReady(
f"Timeout while connecting for entry '{start} {destination}'"
f"Timeout while connecting for entry '{unique_id}'"
) from e
except OpendataTransportError as e:
raise ConfigEntryError(
f"Setup failed for entry '{start} {destination}' with invalid data, check "
f"Setup failed for entry '{unique_id}' with invalid data, check "
"at http://transport.opendata.ch/examples/stationboard.html if your "
"station names are valid"
) from e
Expand Down Expand Up @@ -76,11 +78,9 @@ async def async_migrate_entry(
# This means the user has downgraded from a future version
return False

if config_entry.minor_version == 1:
if config_entry.version == 1 and config_entry.minor_version == 1:
# Remove wrongly registered devices and entries
new_unique_id = (
f"{config_entry.data[CONF_START]} {config_entry.data[CONF_DESTINATION]}"
)
new_unique_id = unique_id_from_config(config_entry.data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This addition of via should probably bump the config entry version again as it's not really backwards compatible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that it is forward compatible and does not need a migration but it going back will break. Is that correct? @gjohansson-ST

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this doesn't need to bump the major version as with the addition of via in the unique id I don't think it would work if a user goes backwards in versions.
I don't think bumping a minor version is enough.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not that familiar with this kind of policy within home assistant. Currently I only add via stations, but there are some more useful options in the api which will require another jump in the future. What would you suggest? Be on the save side and jump a major? @gjohansson-ST

Copy link
Contributor Author

@miaucl miaucl Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gjohansson-ST We already had a discussion about that here: #107087 (comment)

My interpretation would be to make sure one cannot bump back to a prior version (1.2) as it would break and therefore switch the version bump to a major version to (2.1) instead of (1.3).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also extended the unit test for the migration to be able to test all different starting versions using parametrisation.

entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
device_entries = dr.async_entries_for_config_entry(
Expand Down Expand Up @@ -109,6 +109,10 @@ async def async_migrate_entry(
config_entry, unique_id=new_unique_id, minor_version=2
)

if config_entry.version == 1 and config_entry.minor_version == 2:
# Via stations now available, migrate to version 2.1
hass.config_entries.async_update_entry(config_entry, version=2, minor_version=1)

_LOGGER.debug(
"Migration to version %s.%s successful",
config_entry.version,
Expand Down
76 changes: 47 additions & 29 deletions homeassistant/components/swiss_public_transport/config_flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Config flow for swiss_public_transport."""

import logging
from types import MappingProxyType
from typing import Any

from opendata_transport import OpendataTransport
Expand All @@ -14,12 +15,24 @@
from homeassistant.const import CONF_NAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
TextSelector,
TextSelectorConfig,
TextSelectorType,
)

from .const import CONF_DESTINATION, CONF_START, DOMAIN, PLACEHOLDERS
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS
from .helper import unique_id_from_config

DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_START): cv.string,
vol.Optional(CONF_VIA): TextSelector(
TextSelectorConfig(
type=TextSelectorType.TEXT,
multiple=True,
),
),
vol.Required(CONF_DESTINATION): cv.string,
}
)
Expand All @@ -30,38 +43,43 @@
class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN):
"""Swiss public transport config flow."""

VERSION = 1
MINOR_VERSION = 2
VERSION = 2
MINOR_VERSION = 1

async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Async user step to set up the connection."""
errors: dict[str, str] = {}
if user_input is not None:
await self.async_set_unique_id(
f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}"
)
unique_id = unique_id_from_config(MappingProxyType(user_input))
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START], user_input[CONF_DESTINATION], session
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA:
errors["base"] = "too_many_via_stations"
else:
return self.async_create_entry(
title=f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}",
data=user_input,
session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START],
user_input[CONF_DESTINATION],
session,
via=user_input.get(CONF_VIA),
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
else:
return self.async_create_entry(
title=unique_id,
data=user_input,
)

return self.async_show_form(
step_id="user",
Expand All @@ -72,14 +90,15 @@ async def async_step_user(

async def async_step_import(self, import_input: dict[str, Any]) -> ConfigFlowResult:
"""Async import step to set up the connection."""
await self.async_set_unique_id(
f"{import_input[CONF_START]} {import_input[CONF_DESTINATION]}"
)
unique_id = unique_id_from_config(MappingProxyType(import_input))
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
import_input[CONF_START], import_input[CONF_DESTINATION], session
import_input[CONF_START],
import_input[CONF_DESTINATION],
session,
)
try:
await opendata.async_get_data()
Expand All @@ -89,9 +108,8 @@ async def async_step_import(self, import_input: dict[str, Any]) -> ConfigFlowRes
return self.async_abort(reason="bad_config")
except Exception: # pylint: disable=broad-except
_LOGGER.error(
"Unknown error raised by python-opendata-transport for '%s %s', check at http://transport.opendata.ch/examples/stationboard.html if your station names and your parameters are valid",
import_input[CONF_START],
import_input[CONF_DESTINATION],
"Unknown error raised by python-opendata-transport for '%s', check at http://transport.opendata.ch/examples/stationboard.html if your station names and your parameters are valid",
unique_id,
)
return self.async_abort(reason="unknown")

Expand Down
8 changes: 6 additions & 2 deletions homeassistant/components/swiss_public_transport/const.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Constants for the swiss_public_transport integration."""

from typing import Final

DOMAIN = "swiss_public_transport"

CONF_DESTINATION = "to"
CONF_START = "from"
CONF_DESTINATION: Final = "to"
CONF_START: Final = "from"
CONF_VIA: Final = "via"

DEFAULT_NAME = "Next Destination"

MAX_VIA = 5
SENSOR_CONNECTIONS_COUNT = 3


Expand Down
15 changes: 15 additions & 0 deletions homeassistant/components/swiss_public_transport/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Helper functions for swiss_public_transport."""

from types import MappingProxyType
from typing import Any

from .const import CONF_DESTINATION, CONF_START, CONF_VIA


def unique_id_from_config(config: MappingProxyType[str, Any]) -> str:
"""Build a unique id from a config entry."""
return f"{config[CONF_START]} {config[CONF_DESTINATION]}" + (
" via " + ", ".join(config[CONF_VIA])
if CONF_VIA in config and len(config[CONF_VIA]) > 0
else ""
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"error": {
"cannot_connect": "Cannot connect to server",
"bad_config": "Request failed due to bad config: Check at [stationboard]({stationboard_url}) if your station names are valid",
"too_many_via_stations": "Too many via stations, only up to 5 via stations are allowed per connection.",
"unknown": "An unknown error was raised by python-opendata-transport"
},
"abort": {
Expand All @@ -15,9 +16,10 @@
"user": {
"data": {
"from": "Start station",
"to": "End station"
"to": "End station",
"via": "List of up to 5 via stations"
},
"description": "Provide start and end station for your connection\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"title": "Swiss Public Transport"
}
}
Expand Down
58 changes: 45 additions & 13 deletions tests/components/swiss_public_transport/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from homeassistant.components.swiss_public_transport.const import (
CONF_DESTINATION,
CONF_START,
CONF_VIA,
MAX_VIA,
)
from homeassistant.components.swiss_public_transport.helper import unique_id_from_config
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
Expand All @@ -27,8 +30,36 @@
CONF_DESTINATION: "test_destination",
}

MOCK_DATA_STEP_ONE_VIA = {
**MOCK_DATA_STEP,
CONF_VIA: ["via_station"],
}

MOCK_DATA_STEP_MANY_VIA = {
**MOCK_DATA_STEP,
CONF_VIA: ["via_station_1", "via_station_2", "via_station_3"],
}

MOCK_DATA_STEP_TOO_MANY_STATIONS = {
**MOCK_DATA_STEP,
CONF_VIA: MOCK_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1),
}


async def test_flow_user_init_data_success(hass: HomeAssistant) -> None:
@pytest.mark.parametrize(
("user_input", "config_title"),
[
(MOCK_DATA_STEP, "test_start test_destination"),
(MOCK_DATA_STEP_ONE_VIA, "test_start test_destination via via_station"),
(
MOCK_DATA_STEP_MANY_VIA,
"test_start test_destination via via_station_1, via_station_2, via_station_3",
),
],
)
async def test_flow_user_init_data_success(
hass: HomeAssistant, user_input, config_title
) -> None:
"""Test success response."""
result = await hass.config_entries.flow.async_init(
config_flow.DOMAIN, context={"source": "user"}
Expand All @@ -49,25 +80,26 @@ async def test_flow_user_init_data_success(hass: HomeAssistant) -> None:
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=MOCK_DATA_STEP,
user_input=user_input,
)

assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination"
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["result"].title == config_title

assert result["data"] == MOCK_DATA_STEP
assert result["data"] == user_input


@pytest.mark.parametrize(
("raise_error", "text_error"),
("raise_error", "text_error", "user_input_error"),
[
(OpendataTransportConnectionError(), "cannot_connect"),
(OpendataTransportError(), "bad_config"),
(IndexError(), "unknown"),
(OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP),
(OpendataTransportError(), "bad_config", MOCK_DATA_STEP),
(None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS),
(IndexError(), "unknown", MOCK_DATA_STEP),
],
)
async def test_flow_user_init_data_error_and_recover(
hass: HomeAssistant, raise_error, text_error
hass: HomeAssistant, raise_error, text_error, user_input_error
) -> None:
"""Test unknown errors."""
with patch(
Expand All @@ -80,7 +112,7 @@ async def test_flow_user_init_data_error_and_recover(
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=MOCK_DATA_STEP,
user_input=user_input_error,
)

assert result["type"] is FlowResultType.FORM
Expand All @@ -94,7 +126,7 @@ async def test_flow_user_init_data_error_and_recover(
user_input=MOCK_DATA_STEP,
)

assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination"

assert result["data"] == MOCK_DATA_STEP
Expand All @@ -106,7 +138,7 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No
entry = MockConfigEntry(
domain=config_flow.DOMAIN,
data=MOCK_DATA_STEP,
unique_id=f"{MOCK_DATA_STEP[CONF_START]} {MOCK_DATA_STEP[CONF_DESTINATION]}",
unique_id=unique_id_from_config(MOCK_DATA_STEP),
)
entry.add_to_hass(hass)

Expand Down