Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor group state logic #116318

Merged
merged 11 commits into from
May 2, 2024
44 changes: 39 additions & 5 deletions homeassistant/components/group/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from typing import Any

from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
Expand Down Expand Up @@ -133,6 +133,7 @@ class Group(Entity):
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
single_state_type_key: tuple[str, str] | None

def __init__(
self,
Expand All @@ -153,7 +154,7 @@ def __init__(
self._attr_name = name
self._state: str | None = None
self._attr_icon = icon
self._set_tracked(entity_ids)
self._entity_ids = entity_ids
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
self._on_off: dict[str, bool] = {}
self._assumed: dict[str, bool] = {}
self._on_states: set[str] = set()
Expand Down Expand Up @@ -287,23 +288,50 @@ def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_state_type_key = None
return

registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
excluded_domains = registry.exclude_domains

tracking: list[str] = []
trackable: list[str] = []
single_state_type_key: set[tuple[str, str]] = set()
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in registry.state_group_mapping:
single_state_type_key.add(registry.state_group_mapping[domain])
elif domain == DOMAIN:
# If a group contains another group we check if that group
# has a specific single state type
if ent_id in registry.state_group_mapping:
single_state_type_key.add(registry.state_group_mapping[ent_id])
else:
single_state_type_key.add((STATE_ON, STATE_OFF))

if len(single_state_type_key) == 1:
self.single_state_type_key = next(iter(single_state_type_key))
# To support groups with nested groups we store the state type
# per group entity_id if there is a single state type
registry.state_group_mapping[self.entity_id] = self.single_state_type_key
else:
self.single_state_type_key = None
self.async_on_remove(self._async_deregister)

self.trackable = tuple(trackable)
self.tracking = tuple(tracking)

@callback
def _async_deregister(self) -> None:
"""Deregister group entity from the registry."""
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
if self.entity_id in registry.state_group_mapping:
registry.state_group_mapping.pop(self.entity_id)

@callback
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
Expand Down Expand Up @@ -342,6 +370,7 @@ def async_update_group_state(self) -> None:

async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant."""
self._set_tracked(self._entity_ids)
self.async_on_remove(start.async_at_start(self.hass, self._async_start))

async def async_will_remove_from_hass(self) -> None:
Expand Down Expand Up @@ -430,12 +459,14 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None:
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = list(self._on_states)[0]
on_state = next(iter(self._on_states))
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
if self.single_state_type_key:
on_state = self.single_state_type_key[0]
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
Expand All @@ -444,8 +475,11 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None:
if group_is_on:
self._state = on_state
else:
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]
self._state = (
self.single_state_type_key[1]
if self.single_state_type_key
else STATE_OFF
)
jbouwh marked this conversation as resolved.
Show resolved Hide resolved


def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/components/group/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .const import DOMAIN, REG_KEY

if TYPE_CHECKING:
from .entity import Group
pass
bdraco marked this conversation as resolved.
Show resolved Hide resolved


async def async_setup(hass: HomeAssistant) -> None:
Expand Down Expand Up @@ -54,7 +54,6 @@ def __init__(self, hass: HomeAssistant) -> None:
self.on_states_by_domain: dict[str, set[str]] = {}
self.exclude_domains: set[str] = set()
self.state_group_mapping: dict[str, tuple[str, str]] = {}
self.group_entities: set[Group] = set()

@callback
def exclude_domain(self, domain: str) -> None:
Expand All @@ -70,7 +69,8 @@ def on_off_states(
if on_state not in self.on_off_mapping:
self.on_off_mapping[on_state] = off_state

if len(on_states) == 1 and off_state not in self.off_on_mapping:
if off_state not in self.off_on_mapping:
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
self.off_on_mapping[off_state] = default_on_state
self.state_group_mapping[domain] = (default_on_state, off_state)

self.on_states_by_domain[domain] = on_states