Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thread safety checks to async_create_task #116339

Merged
merged 13 commits into from
Apr 28, 2024
2 changes: 1 addition & 1 deletion homeassistant/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ async def async_setup_multi_components(
# to wait to be imported, and the sooner we can get the base platforms
# loaded the sooner we can start loading the rest of the integrations.
futures = {
domain: hass.async_create_task(
domain: hass.async_create_task_internal(
async_setup_component(hass, domain, config),
f"setup component {domain}",
eager_start=True,
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def async_create_task(

target: target to call.
"""
task = hass.async_create_task(
task = hass.async_create_task_internal(
target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start
)
if eager_start and task.done():
Expand Down Expand Up @@ -1643,7 +1643,7 @@ async def async_remove(self, entry_id: str) -> dict[str, Any]:
# starting a new flow with the 'unignore' step. If the integration doesn't
# implement async_step_unignore then this will be a no-op.
if entry.source == SOURCE_IGNORE:
self.hass.async_create_task(
self.hass.async_create_task_internal(
self.hass.config_entries.flow.async_init(
entry.domain,
context={"source": SOURCE_UNIGNORE},
Expand Down
37 changes: 35 additions & 2 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,9 @@ def create_task(
target: target to call.
"""
self.loop.call_soon_threadsafe(
functools.partial(self.async_create_task, target, name, eager_start=True)
functools.partial(
self.async_create_task_internal, target, name, eager_start=True
)
)

@callback
Expand All @@ -800,6 +802,37 @@ def async_create_task(
This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.

target: target to call.
"""
# We turned on asyncio debug in April 2024 in the dev containers
# in the hope of catching some of the issues that have been
# reported. It will take a while to get all the issues fixed in
# custom components.
#
# In 2025.5 we should guard the `verify_event_loop_thread`
# check with a check for the `hass.config.debug` flag being set as
# long term we don't want to be checking this in production
# environments since it is a performance hit.
self.verify_event_loop_thread("async_create_task")
return self.async_create_task_internal(target, name, eager_start)

@callback
def async_create_task_internal(
self,
target: Coroutine[Any, Any, _R],
name: str | None = None,
eager_start: bool = True,
) -> asyncio.Task[_R]:
"""Create a task from within the event loop, internal use only.

This method is intended to only be used by core internally
and should not be considered a stable API. We will make
breaking change to this function in the future and it
bdraco marked this conversation as resolved.
Show resolved Hide resolved
should not be used in integrations.

This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.

target: target to call.
"""
if eager_start:
Expand Down Expand Up @@ -2697,7 +2730,7 @@ async def async_call(

coro = self._execute_service(handler, service_call)
if not blocking:
self._hass.async_create_task(
self._hass.async_create_task_internal(
self._run_service_call_catch_exceptions(coro, service_call),
f"service call background {service_call.domain}.{service_call.service}",
eager_start=True,
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ def _async_registry_updated(
is_remove = action == "remove"
self._removed_from_registry = is_remove
if action == "update" or is_remove:
self.hass.async_create_task(
self.hass.async_create_task_internal(
self._async_process_registry_update_or_remove(event), eager_start=True
)

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/entity_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def async_setup(self, config: ConfigType) -> None:
# Look in config for Domain, Domain 2, Domain 3 etc and load them
for p_type, p_config in conf_util.config_per_platform(config, self.domain):
if p_type is not None:
self.hass.async_create_task(
self.hass.async_create_task_internal(
self.async_setup_platform(p_type, p_config),
f"EntityComponent setup platform {p_type} {self.domain}",
eager_start=True,
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _async_schedule_add_entities(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform async."""
task = self.hass.async_create_task(
task = self.hass.async_create_task_internal(
self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities {self.domain}.{self.platform_name}",
eager_start=True,
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/helpers/integration_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _async_integration_platform_component_loaded(

# At least one of the platforms is not loaded, we need to load them
# so we have to fall back to creating a task.
hass.async_create_task(
hass.async_create_task_internal(
_async_process_integration_platforms_for_component(
hass, integration, platforms_that_exist, integration_platforms_by_name
),
Expand Down Expand Up @@ -206,7 +206,7 @@ async def async_process_integration_platforms(
# We use hass.async_create_task instead of asyncio.create_task because
# we want to make sure that startup waits for the task to complete.
#
future = hass.async_create_task(
future = hass.async_create_task_internal(
_async_process_integration_platforms(
hass, platform_name, top_level_components.copy(), process_job
),
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ async def async_call_service(
)

await self._run_then_background(
hass.async_create_task(
hass.async_create_task_internal(
hass.services.async_call(
domain,
service,
Expand Down
4 changes: 3 additions & 1 deletion homeassistant/helpers/restore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ async def _async_dump_states(*_: Any) -> None:
# Dump the initial states now. This helps minimize the risk of having
# old states loaded by overwriting the last states once Home Assistant
# has started and the old states have been read.
self.hass.async_create_task(_async_dump_states(), "RestoreStateData dump")
self.hass.async_create_task_internal(
_async_dump_states(), "RestoreStateData dump"
)

# Dump states periodically
cancel_interval = async_track_time_interval(
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/helpers/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ async def _async_call_service_step(self):
)
trace_set_result(params=params, running_script=running_script)
response_data = await self._async_run_long_action(
self._hass.async_create_task(
self._hass.async_create_task_internal(
self._hass.services.async_call(
**params,
blocking=True,
Expand Down Expand Up @@ -1208,7 +1208,7 @@ async def async_run_with_trace(idx: int, script: Script) -> None:
async def _async_run_script(self, script: Script) -> None:
"""Execute a script."""
result = await self._async_run_long_action(
self._hass.async_create_task(
self._hass.async_create_task_internal(
script.async_run(self._variables, self._context), eager_start=True
)
)
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _async_schedule_callback_delayed_write(self) -> None:
# wrote. Reschedule the timer to the next write time.
self._async_reschedule_delayed_write(self._next_write_time)
return
self.hass.async_create_task(
self.hass.async_create_task_internal(
self._async_callback_delayed_write(), eager_start=True
)

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ async def when_setup() -> None:
_LOGGER.exception("Error handling when_setup callback for %s", component)

if component in hass.config.components:
hass.async_create_task(
hass.async_create_task_internal(
when_setup(), f"when setup {component}", eager_start=True
)
return
Expand Down
8 changes: 4 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async def async_test_home_assistant(

orig_async_add_job = hass.async_add_job
orig_async_add_executor_job = hass.async_add_executor_job
orig_async_create_task = hass.async_create_task
orig_async_create_task_internal = hass.async_create_task_internal
orig_tz = dt_util.DEFAULT_TIME_ZONE

def async_add_job(target, *args, eager_start: bool = False):
Expand Down Expand Up @@ -263,18 +263,18 @@ def async_add_executor_job(target, *args):

return orig_async_add_executor_job(target, *args)

def async_create_task(coroutine, name=None, eager_start=True):
def async_create_task_internal(coroutine, name=None, eager_start=True):
"""Create task."""
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
fut = asyncio.Future()
fut.set_result(None)
return fut

return orig_async_create_task(coroutine, name, eager_start)
return orig_async_create_task_internal(coroutine, name, eager_start)

hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task
hass.async_create_task_internal = async_create_task_internal

hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

Expand Down
18 changes: 15 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ async def test_async_create_task_schedule_coroutine() -> None:
async def job():
pass

ha.HomeAssistant.async_create_task(hass, job(), eager_start=False)
ha.HomeAssistant.async_create_task_internal(hass, job(), eager_start=False)
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0
Expand All @@ -342,7 +342,7 @@ async def test_async_create_task_eager_start_schedule_coroutine() -> None:
async def job():
pass

ha.HomeAssistant.async_create_task(hass, job(), eager_start=True)
ha.HomeAssistant.async_create_task_internal(hass, job(), eager_start=True)
# Should create the task directly since 3.12 supports eager_start
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0
Expand All @@ -355,7 +355,7 @@ async def test_async_create_task_schedule_coroutine_with_name() -> None:
async def job():
pass

task = ha.HomeAssistant.async_create_task(
task = ha.HomeAssistant.async_create_task_internal(
hass, job(), "named task", eager_start=False
)
assert len(hass.loop.call_soon.mock_calls) == 0
Expand Down Expand Up @@ -3480,3 +3480,15 @@ async def test_async_remove_thread_safety(hass: HomeAssistant) -> None:
await hass.async_add_executor_job(
hass.services.async_remove, "test_domain", "test_service"
)


async def test_async_create_task_thread_safety(hass: HomeAssistant) -> None:
"""Test async_create_task thread safety."""

async def _any_coro():
pass

with pytest.raises(
RuntimeError, match="Detected code that calls async_create_task from a thread."
):
await hass.async_add_executor_job(hass.async_create_task, _any_coro)