Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor group state logic #116318

Merged
merged 11 commits into from
May 2, 2024
44 changes: 38 additions & 6 deletions homeassistant/components/group/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from typing import Any

from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
Expand All @@ -24,7 +24,7 @@
from homeassistant.helpers.event import async_track_state_change_event

from .const import ATTR_AUTO, ATTR_ORDER, DOMAIN, GROUP_ORDER, REG_KEY
from .registry import GroupIntegrationRegistry
from .registry import GroupIntegrationRegistry, SingleStateType

ENTITY_ID_FORMAT = DOMAIN + ".{}"

Expand Down Expand Up @@ -133,6 +133,7 @@ class Group(Entity):
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
single_state_type_key: SingleStateType | None

def __init__(
self,
Expand All @@ -153,7 +154,7 @@ def __init__(
self._attr_name = name
self._state: str | None = None
self._attr_icon = icon
self._set_tracked(entity_ids)
self._entity_ids = entity_ids
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
self._on_off: dict[str, bool] = {}
self._assumed: dict[str, bool] = {}
self._on_states: set[str] = set()
Expand Down Expand Up @@ -287,23 +288,50 @@ def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_state_type_key = None
return

registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
excluded_domains = registry.exclude_domains

tracking: list[str] = []
trackable: list[str] = []
single_state_type_set: set[SingleStateType] = set()
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[domain])
elif domain == DOMAIN:
# If a group contains another group we check if that group
# has a specific single state type
if ent_id in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[ent_id])
else:
single_state_type_set.add(SingleStateType(STATE_ON, STATE_OFF))

if len(single_state_type_set) == 1:
self.single_state_type_key = next(iter(single_state_type_set))
# To support groups with nested groups we store the state type
# per group entity_id if there is a single state type
registry.state_group_mapping[self.entity_id] = self.single_state_type_key
else:
self.single_state_type_key = None
self.async_on_remove(self._async_deregister)

self.trackable = tuple(trackable)
self.tracking = tuple(tracking)

@callback
def _async_deregister(self) -> None:
"""Deregister group entity from the registry."""
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
if self.entity_id in registry.state_group_mapping:
registry.state_group_mapping.pop(self.entity_id)

@callback
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
Expand Down Expand Up @@ -342,6 +370,7 @@ def async_update_group_state(self) -> None:

async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant."""
self._set_tracked(self._entity_ids)
self.async_on_remove(start.async_at_start(self.hass, self._async_start))

async def async_will_remove_from_hass(self) -> None:
Expand Down Expand Up @@ -430,22 +459,25 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None:
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = list(self._on_states)[0]
on_state = next(iter(self._on_states))
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
if self.single_state_type_key:
on_state = self.single_state_type_key.on_state
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
on_state = STATE_ON
group_is_on = self.mode(self._on_off.values())
if group_is_on:
self._state = on_state
elif self.single_state_type_key:
self._state = self.single_state_type_key.off_state
else:
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]
self._state = STATE_OFF


def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:
Expand Down
30 changes: 21 additions & 9 deletions homeassistant/components/group/registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Provide the functionality to group entities."""
"""Provide the functionality to group entities.

Legacy group support will not be extended for new domains.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
from dataclasses import dataclass
from typing import Protocol

from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, callback
Expand All @@ -12,9 +16,6 @@

from .const import DOMAIN, REG_KEY

if TYPE_CHECKING:
from .entity import Group


async def async_setup(hass: HomeAssistant) -> None:
"""Set up the Group integration registry of integration platforms."""
Expand Down Expand Up @@ -43,6 +44,14 @@ def _process_group_platform(
platform.async_describe_on_off_states(hass, registry)


@dataclass(frozen=True, slots=True)
class SingleStateType:
"""Dataclass to store a single state type."""

on_state: str
off_state: str


class GroupIntegrationRegistry:
"""Class to hold a registry of integrations."""

Expand All @@ -53,8 +62,7 @@ def __init__(self, hass: HomeAssistant) -> None:
self.off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON}
self.on_states_by_domain: dict[str, set[str]] = {}
self.exclude_domains: set[str] = set()
self.state_group_mapping: dict[str, tuple[str, str]] = {}
self.group_entities: set[Group] = set()
self.state_group_mapping: dict[str, SingleStateType] = {}

@callback
def exclude_domain(self, domain: str) -> None:
Expand All @@ -65,12 +73,16 @@ def exclude_domain(self, domain: str) -> None:
def on_off_states(
self, domain: str, on_states: set[str], default_on_state: str, off_state: str
) -> None:
"""Register on and off states for the current domain."""
"""Register on and off states for the current domain.

Legacy group support will not be extended for new domains.
"""
for on_state in on_states:
if on_state not in self.on_off_mapping:
self.on_off_mapping[on_state] = off_state

if len(on_states) == 1 and off_state not in self.off_on_mapping:
if off_state not in self.off_on_mapping:
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
self.off_on_mapping[off_state] = default_on_state
self.state_group_mapping[domain] = SingleStateType(default_on_state, off_state)

self.on_states_by_domain[domain] = on_states