Skip to content

Commit

Permalink
Add new switch task type for running different versions of a task #83
Browse files Browse the repository at this point in the history
Also clean up documentation of task types.
  • Loading branch information
nat-n committed Dec 24, 2022
1 parent 5951507 commit 3224e6b
Show file tree
Hide file tree
Showing 13 changed files with 771 additions and 206 deletions.
504 changes: 336 additions & 168 deletions README.rst

Large diffs are not rendered by default.

26 changes: 18 additions & 8 deletions poethepoet/context.py
Expand Up @@ -58,26 +58,36 @@ def get_task_env(

# Include env vars from dependencies
if task_uses is not None:
result.update(self.get_dep_values(task_uses))
result.update(self._get_dep_values(task_uses))

return result

def get_dep_values(
def _get_dep_values(
self, used_task_invocations: Mapping[str, Tuple[str, ...]]
) -> Dict[str, str]:
"""
Get env vars from upstream tasks declared via the uses option.
New lines are replaced with whitespace similar to how unquoted command
interpolation works in bash.
"""
return {
var_name: re.sub(
r"\s+", " ", self.captured_stdout[invocation].strip("\r\n")
)
var_name: self.get_task_output(invocation)
for var_name, invocation in used_task_invocations.items()
}

def save_task_output(self, invocation: Tuple[str, ...], captured_stdout: bytes):
"""
Store the stdout data from a task so that it can be reused by other tasks
"""
self.captured_stdout[invocation] = captured_stdout.decode()

def get_task_output(self, invocation: Tuple[str, ...]):
"""
Get the stored stdout data from a task so that it can be reused by other tasks
New lines are replaced with whitespace similar to how unquoted command
interpolation works in bash.
"""
return re.sub(r"\s+", " ", self.captured_stdout[invocation].strip("\r\n"))

def get_executor(
self,
invocation: Tuple[str, ...],
Expand Down
15 changes: 12 additions & 3 deletions poethepoet/env/manager.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union

from .cache import EnvFileCache
from .template import apply_envvars_to_template
Expand Down Expand Up @@ -91,8 +91,17 @@ def for_task(

return result

def update(self, env_vars: Mapping[str, str]):
self._vars.update(env_vars)
def update(self, env_vars: Mapping[str, Any]):
# ensure all values are strings
str_vars: Dict[str, str] = {}
for key, value in env_vars.items():
if isinstance(value, list):
str_vars[key] = " ".join(str(item) for item in value)
elif value is not None:
str_vars[key] = str(value)

self._vars.update(str_vars)

return self

def to_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion poethepoet/executor/base.py
Expand Up @@ -217,7 +217,7 @@ def handle_signal(signum, _frame):
(captured_stdout, _) = proc.communicate(input)

if self.capture_stdout == True:
self.context.captured_stdout[self.invocation] = captured_stdout.decode()
self.context.save_task_output(self.invocation, captured_stdout)

# restore signal handler
signal.signal(signal.SIGINT, old_signal_handler)
Expand Down
7 changes: 3 additions & 4 deletions poethepoet/helpers/python.py
Expand Up @@ -31,6 +31,7 @@
"next",
"oct",
"ord",
"os",
"pow",
"repr",
"round",
Expand Down Expand Up @@ -58,6 +59,7 @@
"set",
"slice",
"str",
"sys",
"tuple",
"type",
"zip",
Expand Down Expand Up @@ -95,14 +97,11 @@ def resolve_function_call(
),
)
for node in name_nodes:
if node.id in _BUILTINS_WHITELIST:
# builtin values have precedence over unqualified args
continue
if node.id in arguments:
substitutions.append(
(_get_name_node_abs_range(source, node), args_prefix + node.id)
)
else:
elif node.id not in _BUILTINS_WHITELIST:
raise ScriptParseError(
"Invalid variable reference in script: "
+ _get_name_source_segment(source, node)
Expand Down
1 change: 1 addition & 0 deletions poethepoet/task/__init__.py
Expand Up @@ -4,3 +4,4 @@
from .script import ScriptTask
from .sequence import SequenceTask
from .shell import ShellTask
from .switch import SwitchTask
14 changes: 1 addition & 13 deletions poethepoet/task/base.py
Expand Up @@ -209,26 +209,14 @@ def _parse_named_args(
return PoeTaskArgs(args_def, self.name, env).parse(extra_args)
return None

# @property
# def has_named_args(self):
# return bool(self.named_args)

def get_named_arg_values(self, env: EnvVarsManager) -> Dict[str, str]:
result: Dict[str, str] = {}

if self.named_args is None:
self.named_args = self._parse_named_args(self.invocation[1:], env)

if not self.named_args:
return {}

for key, value in self.named_args.items():
if isinstance(value, list):
result[key] = " ".join(str(item) for item in value)
elif value is not None:
result[key] = str(value)

return result
return self.named_args

def run(
self,
Expand Down
8 changes: 6 additions & 2 deletions poethepoet/task/ref.py
Expand Up @@ -28,8 +28,12 @@ def _handle_run(
"""
Lookup and delegate to the referenced task
"""
invocation = tuple(shlex.split(env.fill_template(self.content.strip())))
task = self.from_config(invocation[0], self._config, self._ui, invocation)
self.ref_invocation = tuple(
shlex.split(env.fill_template(self.content.strip()))
)
task = self.from_config(
self.ref_invocation[0], self._config, self._ui, self.ref_invocation
)
return task.run(context=context, extra_args=extra_args, parent_env=env)

@classmethod
Expand Down
13 changes: 8 additions & 5 deletions poethepoet/task/script.py
Expand Up @@ -32,20 +32,23 @@ def _handle_run(
extra_args: Sequence[str],
env: EnvVarsManager,
) -> int:
# TODO: check whether the project really does use src layout, and don't do
# sys.path.append('src') if it doesn't
target_module, function_call = self.parse_script_content(self.named_args)
named_arg_values = self.get_named_arg_values(env)
env.update(named_arg_values)
target_module, function_call = self.parse_script_content(named_arg_values)
argv = [
self.name,
*(env.fill_template(token) for token in extra_args),
]

# TODO: check whether the project really does use src layout, and don't do
# sys.path.append('src') if it doesn't

script = [
"import sys; ",
"import os,sys; ",
"from os import environ; ",
"from importlib import import_module; ",
f"sys.argv = {argv!r}; sys.path.append('src');",
f"{self.format_args_class(self.named_args)}",
f"{self.format_args_class(named_arg_values)}",
f"result = import_module('{target_module}').{function_call};",
]

Expand Down
196 changes: 196 additions & 0 deletions poethepoet/task/switch.py
@@ -0,0 +1,196 @@
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

from ..env.manager import EnvVarsManager
from ..exceptions import ExecutionError, PoeException
from .base import PoeTask, TaskContent

if TYPE_CHECKING:
from ..config import PoeConfig
from ..context import RunContext
from ..executor import PoeExecutor
from ..ui import PoeUi


DEFAULT_CASE = "__default__"


class SwitchTask(PoeTask):
"""
A task that runs one of several `case` subtasks depending on the output of a
`switch` subtask.
"""

content: List[Union[str, Dict[str, Any]]]

__key__ = "switch"
__content_type__: Type = list
__options__: Dict[str, Union[Type, Tuple[Type, ...]]] = {
"control": (str, dict),
"default": str,
}

def __init__(
self,
name: str,
content: TaskContent,
options: Dict[str, Any],
ui: "PoeUi",
config: "PoeConfig",
invocation: Tuple[str, ...],
capture_stdout: bool = False,
):
super().__init__(name, content, options, ui, config, invocation)

control_task_name = f"{name}[control]"
control_invocation: Tuple[str, ...] = (control_task_name,)
if self.options.get("args"):
self.options["control"]["args"] = self.options["args"]
control_invocation = (*control_invocation, *invocation[1:])

self.control_task = self.from_def(
task_def=self.options.get("control", ""),
task_name=control_task_name,
config=config,
invocation=control_invocation,
ui=ui,
capture_stdout=True,
)

self.switch_tasks = {}
for item in cast(List[Dict[str, Any]], content):
task_def = {key: value for key, value in item.items() if key != "case"}

task_invocation: Tuple[str, ...] = (name,)
if self.options.get("args"):
task_def["args"] = self.options["args"]
task_invocation = (*task_invocation, *invocation[1:])

for case_key in self._get_case_keys(item):
self.switch_tasks[case_key] = self.from_def(
task_def=task_def,
task_name=f"{name}__{case_key}",
config=config,
invocation=task_invocation,
ui=ui,
)

def _handle_run(
self,
context: "RunContext",
extra_args: Sequence[str],
env: EnvVarsManager,
) -> int:
named_arg_values = self.get_named_arg_values(env)
env.update(named_arg_values)

if not named_arg_values and any(arg.strip() for arg in extra_args):
raise PoeException(f"Switch task {self.name!r} does not accept arguments")

# Indicate on the global context that there are multiple stages to this task
context.multistage = True

task_result = self.control_task.run(
context=context,
extra_args=extra_args if self.options.get("args") else tuple(),
parent_env=env,
)
if task_result:
raise ExecutionError(
f"Switch task {self.name!r} aborted after failed control task"
)

control_task_output = context.get_task_output(self.control_task.invocation)
case_task = self.switch_tasks.get(
control_task_output, self.switch_tasks.get(DEFAULT_CASE)
)

if case_task is None:
if self.options.get("default", "fail") == "pass":
return 0
raise ExecutionError(
f"Control value {control_task_output!r} did not match any cases in "
f"switch task {self.name!r}."
)

return case_task.run(context=context, extra_args=extra_args, parent_env=env)

@classmethod
def _get_case_keys(cls, task_def: Dict[str, Any]) -> List[Any]:
case_value = task_def.get("case", DEFAULT_CASE)
if isinstance(case_value, list):
return case_value
return [case_value]

@classmethod
def _validate_task_def(
cls, task_name: str, task_def: Dict[str, Any], config: "PoeConfig"
) -> Optional[str]:

control_task_def = task_def.get("control")
if not control_task_def:
return f"Switch task {task_name!r} has no control task."

allowed_control_task_types = ("cmd", "script")
if isinstance(control_task_def, dict) and not any(
key in control_task_def for key in allowed_control_task_types
):
return (
f"Control task for {task_name!r} must have a type that is one of "
f"{allowed_control_task_types!r}"
)

cases: MutableMapping[Any, int] = defaultdict(int)
for switch_task in task_def["switch"]:
for case_key in cls._get_case_keys(switch_task):
cases[case_key] += 1

for invalid_option in ("args", "deps"):
if invalid_option in switch_task:
case_key = switch_task.get("case", DEFAULT_CASE)
if case_key is DEFAULT_CASE:
return (
f"Default case of switch task {task_name!r} includes "
f"invalid option {invalid_option!r}"
)
return (
f"Case {case_key!r} switch task {task_name!r} include invalid "
f"option {invalid_option!r}"
)

for case, count in cases.items():
if count > 1:
if case is DEFAULT_CASE:
return (
f"Switch task {task_name!r} includes more than one default case"
)
return (
f"Switch task {task_name!r} includes more than one case for "
f"{case!r}"
)

if "default" in task_def:
if task_def["default"] not in ("pass", "fail"):
return (
f"The default option for switch task {task_name!r} should be one "
"of ('pass', 'fail')"
)
if DEFAULT_CASE in cases:
return (
f"Switch task {task_name!r} should not have both a default case "
f"and the default option."
)

return None

0 comments on commit 3224e6b

Please sign in to comment.