Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor group state logic #116318

Merged
merged 11 commits into from
May 2, 2024
40 changes: 36 additions & 4 deletions homeassistant/components/group/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from typing import Any

from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
Expand Down Expand Up @@ -133,6 +133,7 @@ class Group(Entity):
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
single_state_type_key: tuple[str, str] | None

def __init__(
self,
Expand Down Expand Up @@ -287,23 +288,49 @@ def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_state_type_key = None
return

registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
excluded_domains = registry.exclude_domains

tracking: list[str] = []
trackable: list[str] = []
single_state_type_key: set[tuple[str, str]] = set()
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in registry.state_group_mapping:
single_state_type_key.add(registry.state_group_mapping[domain])
elif domain == DOMAIN:
for group in iter(registry.group_entities):
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
if group.entity_id == ent_id and group.single_state_type_key:
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
single_state_type_key.add(group.single_state_type_key)
break
else:
single_state_type_key.add((STATE_ON, STATE_OFF))

self.single_state_type_key = (
next(iter(single_state_type_key))
if len(single_state_type_key) == 1
else None
)
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
registry.group_entities.add(self)
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
self.async_on_remove(self._deregister)
jbouwh marked this conversation as resolved.
Show resolved Hide resolved

self.trackable = tuple(trackable)
self.tracking = tuple(tracking)

@callback
def _deregister(self) -> None:
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
"""Deregister group entity from the registry."""
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
if self in registry.group_entities:
registry.group_entities.remove(self)

@callback
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
Expand Down Expand Up @@ -430,12 +457,14 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None:
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = list(self._on_states)[0]
on_state = next(iter(self._on_states))
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
if self.single_state_type_key:
on_state = self.single_state_type_key[0]
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
Expand All @@ -444,8 +473,11 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None:
if group_is_on:
self._state = on_state
else:
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]
self._state = (
self.single_state_type_key[1]
if self.single_state_type_key
else STATE_OFF
)
jbouwh marked this conversation as resolved.
Show resolved Hide resolved


def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/group/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def on_off_states(
if on_state not in self.on_off_mapping:
self.on_off_mapping[on_state] = off_state

if len(on_states) == 1 and off_state not in self.off_on_mapping:
if off_state not in self.off_on_mapping:
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
self.off_on_mapping[off_state] = default_on_state
self.state_group_mapping[domain] = (default_on_state, off_state)

self.on_states_by_domain[domain] = on_states