Skip to content

Commit

Permalink
Refactor group state logic (#116318)
Browse files Browse the repository at this point in the history
* Refactor group state logic

* Fix

* Add helper and tests for groups with entity platforms multiple ON states

* Adress comments

* Do not store object and avoid linear search

* User dataclass, cleanup multiline ternary

* Add test cases for grouped groups

* Remove dead code

* typo in comment

* Update metjod and module docstr
  • Loading branch information
jbouwh committed May 2, 2024
1 parent 8e7026d commit 41b6886
Show file tree
Hide file tree
Showing 3 changed files with 385 additions and 16 deletions.
44 changes: 38 additions & 6 deletions homeassistant/components/group/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from typing import Any

from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
Expand All @@ -24,7 +24,7 @@
from homeassistant.helpers.event import async_track_state_change_event

from .const import ATTR_AUTO, ATTR_ORDER, DOMAIN, GROUP_ORDER, REG_KEY
from .registry import GroupIntegrationRegistry
from .registry import GroupIntegrationRegistry, SingleStateType

ENTITY_ID_FORMAT = DOMAIN + ".{}"

Expand Down Expand Up @@ -133,6 +133,7 @@ class Group(Entity):
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
single_state_type_key: SingleStateType | None

def __init__(
self,
Expand All @@ -153,7 +154,7 @@ def __init__(
self._attr_name = name
self._state: str | None = None
self._attr_icon = icon
self._set_tracked(entity_ids)
self._entity_ids = entity_ids
self._on_off: dict[str, bool] = {}
self._assumed: dict[str, bool] = {}
self._on_states: set[str] = set()
Expand Down Expand Up @@ -287,23 +288,50 @@ def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_state_type_key = None
return

registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
excluded_domains = registry.exclude_domains

tracking: list[str] = []
trackable: list[str] = []
single_state_type_set: set[SingleStateType] = set()
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[domain])
elif domain == DOMAIN:
# If a group contains another group we check if that group
# has a specific single state type
if ent_id in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[ent_id])
else:
single_state_type_set.add(SingleStateType(STATE_ON, STATE_OFF))

if len(single_state_type_set) == 1:
self.single_state_type_key = next(iter(single_state_type_set))
# To support groups with nested groups we store the state type
# per group entity_id if there is a single state type
registry.state_group_mapping[self.entity_id] = self.single_state_type_key
else:
self.single_state_type_key = None
self.async_on_remove(self._async_deregister)

self.trackable = tuple(trackable)
self.tracking = tuple(tracking)

@callback
def _async_deregister(self) -> None:
"""Deregister group entity from the registry."""
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
if self.entity_id in registry.state_group_mapping:
registry.state_group_mapping.pop(self.entity_id)

@callback
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
Expand Down Expand Up @@ -342,6 +370,7 @@ def async_update_group_state(self) -> None:

async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant."""
self._set_tracked(self._entity_ids)
self.async_on_remove(start.async_at_start(self.hass, self._async_start))

async def async_will_remove_from_hass(self) -> None:
Expand Down Expand Up @@ -430,22 +459,25 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None:
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = list(self._on_states)[0]
on_state = next(iter(self._on_states))
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
if self.single_state_type_key:
on_state = self.single_state_type_key.on_state
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
on_state = STATE_ON
group_is_on = self.mode(self._on_off.values())
if group_is_on:
self._state = on_state
elif self.single_state_type_key:
self._state = self.single_state_type_key.off_state
else:
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]
self._state = STATE_OFF


def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:
Expand Down
30 changes: 21 additions & 9 deletions homeassistant/components/group/registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Provide the functionality to group entities."""
"""Provide the functionality to group entities.
Legacy group support will not be extended for new domains.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
from dataclasses import dataclass
from typing import Protocol

from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, callback
Expand All @@ -12,9 +16,6 @@

from .const import DOMAIN, REG_KEY

if TYPE_CHECKING:
from .entity import Group


async def async_setup(hass: HomeAssistant) -> None:
"""Set up the Group integration registry of integration platforms."""
Expand Down Expand Up @@ -43,6 +44,14 @@ def _process_group_platform(
platform.async_describe_on_off_states(hass, registry)


@dataclass(frozen=True, slots=True)
class SingleStateType:
"""Dataclass to store a single state type."""

on_state: str
off_state: str


class GroupIntegrationRegistry:
"""Class to hold a registry of integrations."""

Expand All @@ -53,8 +62,7 @@ def __init__(self, hass: HomeAssistant) -> None:
self.off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON}
self.on_states_by_domain: dict[str, set[str]] = {}
self.exclude_domains: set[str] = set()
self.state_group_mapping: dict[str, tuple[str, str]] = {}
self.group_entities: set[Group] = set()
self.state_group_mapping: dict[str, SingleStateType] = {}

@callback
def exclude_domain(self, domain: str) -> None:
Expand All @@ -65,12 +73,16 @@ def exclude_domain(self, domain: str) -> None:
def on_off_states(
self, domain: str, on_states: set[str], default_on_state: str, off_state: str
) -> None:
"""Register on and off states for the current domain."""
"""Register on and off states for the current domain.
Legacy group support will not be extended for new domains.
"""
for on_state in on_states:
if on_state not in self.on_off_mapping:
self.on_off_mapping[on_state] = off_state

if len(on_states) == 1 and off_state not in self.off_on_mapping:
if off_state not in self.off_on_mapping:
self.off_on_mapping[off_state] = default_on_state
self.state_group_mapping[domain] = SingleStateType(default_on_state, off_state)

self.on_states_by_domain[domain] = on_states

0 comments on commit 41b6886

Please sign in to comment.