Multiple class choices & initialization functions #32

Open
orperel opened this issue Feb 1, 2023 · 9 comments

Comments

@orperel

orperel commented Feb 1, 2023

Hi tyro team,

Thanks for a great library, tyro is a blast and really helps simplify config code so far!

There is a convoluted scenario I'm still trying to figure out how to apply tyro to:
Foo and Bar are interchangeable classes (i.e. think 2 different dataset implementations), and each can be constructed with different constructors.

class Foo(Baz):
    def __init__(self, a, b, c):
        super().__init__(a, b)
        self.c = c

    @classmethod
    def from_X(cls, arg1, arg2):
        return cls(arg1, arg2, "X")

    @classmethod
    def from_Y(cls, arg3, arg4):
        return cls(arg3, arg4, "Y")

class Bar(Baz):
    def __init__(self, a, b, d):
        super().__init__(a, b)
        self.d = d

    @classmethod
    def from_X(cls, arg1, arg2):
        return cls(arg1, arg2, "another X")

    @classmethod
    def from_Z(cls, arg5, arg6):
        return cls(arg5, arg6, "Z")

We have a function which generates a dataclass automatically from functions / class __init__ methods.
We want our config system to be able to:

  1. Specify whether a Foo or Bar config should be created, something similar to the Subcommands, except that we don't have explicit dataclasses (they're constructed dynamically).
     # Goal is to replace those with dynamically created configs, based on the user passing my_baz_impl:foo or my_baz_impl:bar
     def main_func(my_baz_impl: Union[Foo, Bar]) -> None:
         ...
  2. Accommodate the different from_... construction methods within the same config, but still let the CLI show explicitly that these are separate argument groups (i.e. Foo.from_X has arg1, arg2 and Foo.from_Y has arg3, arg4).
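
The dataclass-generating helper mentioned above is not shown in this issue. As a rough illustration only, here is a minimal standard-library sketch of what such a helper might look like. The names config_from_callable and FooFromX are hypothetical, and the sketch assumes plain positional/keyword parameters (no *args/**kwargs, no mutable defaults).

```python
import inspect
from dataclasses import asdict, field, fields, make_dataclass
from typing import Any, Callable


def config_from_callable(fn: Callable[..., Any], cls_name: str) -> type:
    """Build a dataclass whose fields mirror the parameters of `fn` (hypothetical helper)."""
    specs = []
    for name, param in inspect.signature(fn).parameters.items():
        # Fall back to Any when the parameter has no annotation.
        annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
        if param.default is inspect.Parameter.empty:
            specs.append((name, annotation))
        else:
            specs.append((name, annotation, field(default=param.default)))
    return make_dataclass(cls_name, specs)


class Foo:
    def __init__(self, a: int, b: int, c: str) -> None:
        self.a, self.b, self.c = a, b, c

    @classmethod
    def from_X(cls, arg1: int, arg2: int = 2) -> "Foo":
        return cls(arg1, arg2, "X")


# The bound classmethod's signature already excludes `cls`.
FooFromX = config_from_callable(Foo.from_X, "FooFromX")
cfg = FooFromX(arg1=1)
foo = Foo.from_X(**asdict(cfg))
```

The generated class is an ordinary dataclass, so it can be inspected with dataclasses.fields like any hand-written config.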

What's the best practice to approach this with tyro?

Thanks in advance!

@brentyi
Owner

brentyi commented Feb 1, 2023

Thanks for giving the library a try!

It won't look as clean as your factory methods, but how do you feel about defining a dataclass to replace each factory method, something like:

@dataclass
class BarFromX:
    arg1: int
    arg2: int

    def instantiate(self) -> Bar:
        return Bar(self.arg1, self.arg2, "X")

And then taking the union over all of them + calling instantiate() when you want your actual Bar or Foo object? It seems like the easiest short-term solution to me.
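
To make the suggestion concrete, here is a minimal sketch of the "one dataclass per factory method" pattern with a union over the configs. Baz, Foo and Bar are simplified stand-ins for the classes in the question, BazConfig is a hypothetical alias, and the tyro entry point is left as a comment so the snippet runs without tyro installed.

```python
from dataclasses import dataclass
from typing import Union


class Baz:
    def __init__(self, a: int, b: int) -> None:
        self.a, self.b = a, b


class Foo(Baz):
    def __init__(self, a: int, b: int, c: str) -> None:
        super().__init__(a, b)
        self.c = c


class Bar(Baz):
    def __init__(self, a: int, b: int, d: str) -> None:
        super().__init__(a, b)
        self.d = d


@dataclass
class FooFromX:
    arg1: int
    arg2: int

    def instantiate(self) -> Foo:
        return Foo(self.arg1, self.arg2, "X")


@dataclass
class BarFromZ:
    arg5: int
    arg6: int

    def instantiate(self) -> Bar:
        return Bar(self.arg5, self.arg6, "Z")


# Union over all config dataclasses (hypothetical alias name).
BazConfig = Union[FooFromX, BarFromZ]


def main_func(my_baz_impl: BazConfig) -> Baz:
    # Convert the parsed config into the actual object.
    return my_baz_impl.instantiate()


# Assumed entry point when tyro is installed: tyro.cli(main_func)
result = main_func(BarFromZ(arg5=1, arg6=2))
```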


I've thought about factory methods like what you've described a few times (#30 is an attempt at something related). Basic support in tyro would be really easy, but there have been a bunch of roadblocks to doing it cleanly. As a simple example, a common pattern in tyro is to write a function that looks something like:

def main(
    experiment_name: str = "experiment",
    config: Config = Config(...),
) -> None:
    pass

tyro.cli(main)

When Config is a standard dataclass, it's straightforward to figure out how to parse the default Config() instance. Each field in the dataclass produces an argument, and a default value for each argument can be retrieved by checking the attributes of the default instance.
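
As a rough illustration of why the dataclass case is straightforward, the sketch below reads per-field defaults off a default instance using only the standard library. This is not tyro's actual implementation; defaults_from_instance and the flag-naming scheme are assumptions made for the example.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class Config:
    arg1: int = 0
    arg2: int = 0


def defaults_from_instance(instance: Any, prefix: str) -> Dict[str, Any]:
    """Map one flag per dataclass field to the value stored on `instance`."""
    return {
        f"--{prefix}.{f.name}": getattr(instance, f.name)
        for f in fields(instance)
    }


# Every field of the default instance yields an argument and its default.
flag_defaults = defaults_from_instance(Config(arg1=3, arg2=5), "config")
```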

If Config has a custom constructor associated with it, however, it becomes much harder to figure out what to do with the default value on the right-hand side. Ignore it? Raise an error? Neither of these seems ideal. Let's say we have a factory Config.from_args(arg1: int, arg2: int):

def main(
    experiment_name: str = "experiment",
    config: Config = Config.from_args(arg1=3, arg2=5),
) -> None:
    pass

tyro.cli(main)

Intuitively, this should create a --config.arg1 argument with a default of 3 and a --config.arg2 argument with a default of 5. However, there's no way to reliably / cleanly grab these arguments. This issue gets even harder when we start using subcommands and need to associate the default instance with a subcommand.
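
The difficulty can be seen in a small standalone sketch. The Config class here is hypothetical: its factory combines the two arguments, so the resulting instance no longer carries arg1 or arg2, even though the factory's signature still names them.

```python
import inspect


class Config:
    def __init__(self, total: int) -> None:
        self.total = total

    @classmethod
    def from_args(cls, arg1: int, arg2: int) -> "Config":
        # The factory arguments are consumed here and not stored anywhere.
        return cls(total=arg1 + arg2)


default = Config.from_args(arg1=3, arg2=5)

# The parameter names are available from the factory's signature...
param_names = list(inspect.signature(Config.from_args).parameters)

# ...but the values that were passed cannot be read back from the instance.
recovered = {name: getattr(default, name, None) for name in param_names}
```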

Note that hydra-zen (integration discussion in #23) will also get you almost what you want: a dynamic dataclass from a factory method. But since these are produced dynamically they unfortunately can't (well, shouldn't) be used in type signatures.

@orperel
Author

orperel commented Feb 3, 2023

Hi Brent, thanks a lot for the thorough reply!

The idea of combining tyro + hydra-zen turned out great so far :)

With hydra-zen I can generate dynamic configs quickly:

# constructor: Foo.from_X, Foo.from_Y, Bar.from_X, ...
# dynamic_config_class_name: Foo, Bar, ...
cfg = builds(constructor, populate_full_signature=True, zen_dataclass={'cls_name': dynamic_config_class_name})

Then I aggregate them together in a list to define a new TypeVar to create a type signature tyro understands:

# FooBar can combine different ctor configs that can build either Foo or Bar
FooBar = TypeVar(group_name, cfg1, cfg2, cfg3, ...)

def my_main(foobar: FooBar) -> None:
   my_foo_or_bar_instance = instantiate(foobar)  # hydra-zen for cfg -> instance here

tyro.cli(my_main)  # let tyro handle cli

tyro seems to be ok with that, as cfg1, cfg2, ... are treated as options for a subcommand (i.e. python example.py foobar:cfg1 --arg1 5 ...). With this both problems are solved: I can support both multiple classes and multiple construction methods!

The only limitation I have so far is somewhat cosmetic. The usage of subcommands means that:

  • If my main has multiple config groups: my_main(foobar: FooBar, set2: AnotherSet, set3: OneMoreSet) I can't view with --help all options for all config sets at once (subcommands are treated in sequence, but would have been cool to see a full table).
  • The actual arg type that my_main will accept for foobar is dynamic dataclass options like Foo, Bar, etc, but I'd have loved to have them as FooConfig, BarConfig to avoid confusion with the original classes those configs represent. I can change the dynamic_config_class_name I give to hydra-zen, but then the subcommands have to specify -config everywhere: i.e. python example.py foobar:foo-config ... instead of just python example.py foobar:foo ....

Are there any hooks in place to customize subcommand / help behaviors?

@brentyi
Owner

brentyi commented Feb 3, 2023

Glad you got that working!

Depending on how much you care about writing "correct" type signatures, you might consider replacing the TypeVar with a structure that looks something like this:

from typing import Type, TYPE_CHECKING, Union
from hydra_zen.typing import Builds

if TYPE_CHECKING:
    # For type checking, use a statically analyzable type.
    FooBarConfig = Builds[Type[Union[FooConfig, BarConfig]]]
else:
    # At runtime, use the dynamic dataclasses. This will be what's consumed by `tyro`.
    FooBarConfig = Union[cfg1, cfg2, etc]

This is gross but will fix static type resolution + tab complete for my_foo_or_bar_instance in your example.
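
The mechanics of the TYPE_CHECKING split can be tried without hydra-zen. The sketch below is a simplified stand-in: it uses a plain Union in the static branch instead of Builds[...], and make_dataclass to imitate dynamically generated configs. All class names are illustrative.

```python
from dataclasses import dataclass, field, make_dataclass
from typing import TYPE_CHECKING, Union, get_args


@dataclass
class FooConfig:
    arg1: int = 0


@dataclass
class BarConfig:
    arg5: int = 0


# Stand-ins for dataclasses that only exist at runtime.
dynamic_configs = [
    make_dataclass("Foo", [("arg1", int, field(default=0))]),
    make_dataclass("Bar", [("arg5", int, field(default=0))]),
]

if TYPE_CHECKING:
    # Seen only by static type checkers and IDEs.
    FooBarConfig = Union[FooConfig, BarConfig]
else:
    # Executed at runtime, since TYPE_CHECKING is False there.
    FooBarConfig = Union[tuple(dynamic_configs)]
```

At runtime only the else branch executes, so a library that inspects FooBarConfig sees the dynamic classes, while editors resolve the static ones.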

For your two questions:

  • Viewing all subcommands in helptext: this is not something I've explored. Would be happy to take a separate GitHub issue or PR if you have time.
    • The helptext formatting is just hacking at argparse under the hood. My feeling is it wouldn't be too hard to get something working, but might be hard to sufficiently polish? This StackOverflow question also seems relevant.
  • Also kind of verbose, but for configuring individual subcommands you can use something like:
    typing.Annotated[
        YourTypeWhichCouldBeFromHydraZen,
        tyro.conf.subcommand(name="shorter_name", prefix_name=False),
    ]

@orperel
Author

orperel commented Feb 4, 2023

Thanks again! typing.Annotated works like a charm, I have proper classes / subcommand names now! :)
I'll look into the help thingy next.

The TypeVar is a bit more tricky to dispose of, as the example I gave in my previous post is "almost" what happens in practice. I actually ended up using inspect to automatically collect all annotated constructors a class may have (@classmethods that return the class type + init):

dynamic_types = call_inspection_func()  # this one returns [cfg1, cfg2, cfg3...]
FooBar = TypeVar(group_name, *dynamic_types)  # type: ignore
# or alternatively to "fool" the warnings:
# T = TypeVar('T', dynamic_types[0], *dynamic_types[1:])
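
The inspection helper itself is not shown in this thread. A minimal sketch of the idea (collecting the class itself plus every classmethod annotated as returning the class) might look like the following; collect_constructors is a hypothetical name, and the sketch accepts both real and string return annotations.

```python
import inspect
from typing import Any, Callable, List


class Foo:
    def __init__(self, a: int, b: int, c: str) -> None:
        self.a, self.b, self.c = a, b, c

    @classmethod
    def from_X(cls, arg1: int, arg2: int) -> "Foo":
        return cls(arg1, arg2, "X")

    @classmethod
    def describe(cls) -> str:  # Not a constructor: returns str.
        return cls.__name__


def collect_constructors(cls: type) -> List[Callable[..., Any]]:
    """Return `cls` (for __init__) plus classmethods annotated as returning `cls`."""
    constructors: List[Callable[..., Any]] = [cls]
    # Classmethods accessed on the class are bound methods, so ismethod matches them.
    for _name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
        if method.__self__ is not cls:
            continue
        returns = inspect.signature(method).return_annotation
        if returns is cls or returns == cls.__name__:
            constructors.append(method)
    return constructors


found = [c.__name__ for c in collect_constructors(Foo)]
```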

Union doesn't really like the Asterisk (*) operator, which is why I ended up using TypeVar (which is admittedly hacky).
I managed to convert a TypeVar to a Union with the following, but that likely defeats the purpose of your suggestion:

Union[TypeVar(group_name, *dynamic_types).__constraints__]  # type: ignore

One possible tradeoff is to use your approach when users have explicitly written dataclasses for everything, and otherwise sacrifice static type resolution + tab completion?

@brentyi
Owner

brentyi commented Feb 4, 2023

For replacing the TypeVar with a Union, how about Union.__getitem__(tuple(dynamic_types))? My main concern here is that this is not really how TypeVars are meant to be used, and support for TypeVars / generics in tyro is a bit spotty so it might be worth steering clear of them.
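
For reference, a short standard-library check of building a Union from a list: subscripting Union with a tuple behaves the same as writing the members out by hand, so no star-unpacking is needed. The member types here are arbitrary placeholders.

```python
from typing import TypeVar, Union, get_args

# Placeholder for a list of dynamically collected config classes.
dynamic_types = [int, str, bytes]

# Union[a, b, c] receives a tuple under the hood, so a tuple can be passed directly.
FooBar = Union[tuple(dynamic_types)]

# The TypeVar detour from earlier in the thread yields the same members.
T = TypeVar("T", *dynamic_types)
FromTypeVar = Union[T.__constraints__]
```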

Another suggestion is that you could generate the dynamic union type from the statically analyzable one. This might require less boilerplate.

I tried mocking something up for this, which works and seems OK:

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Type, Union, get_args, get_origin

import hydra_zen
import tyro
from hydra_zen.typing import Builds
from typing_extensions import Annotated, reveal_type


class Foo:
    @classmethod
    def from_X(cls, a: int, b: int) -> Foo:
        return Foo()

    @classmethod
    def from_Y(cls, c: int, d: int) -> Foo:
        return Foo()


class Bar:
    @classmethod
    def from_X(cls, a: int, b: int) -> Bar:
        return Bar()

    @classmethod
    def from_Y(cls, c: int, d: int) -> Bar:
        return Bar()


def dynamic_union_from_static_union(typ: Type[Builds[Type]]) -> Any:
    # Builds[Type[Foo | Bar]] => Type[Foo | Bar]
    (typ,) = get_args(typ)

    # Type[Foo | Bar] => Foo | Bar
    assert get_origin(typ) is type
    (union_type,) = get_args(typ)

    # Foo | Bar => Foo, Bar
    config_types = get_args(union_type)

    # Get constructors.
    constructors = []
    for config_type in config_types:
        constructors.extend(
            [
                method
                for name, method in inspect.getmembers(
                    config_type, predicate=inspect.ismethod
                )
                if name.startswith("from_")
                and hasattr(method, "__self__")
                and method.__self__ is config_type
            ]
        )

    # Return union over dynamic dataclasses, one for each constructor type.
    return Union.__getitem__(  # type: ignore
        tuple(
            Annotated[
                # Create the dynamic dataclass.
                hydra_zen.builds(c, populate_full_signature=True),
                # Rename the subcommand.
                tyro.conf.subcommand(
                    c.__self__.__name__.lower() + "_" + c.__name__.lower(),
                    prefix_name=False,
                ),
            ]
            for c in constructors
        )
    )


Config = Builds[Type[Union[Foo, Bar]]]
if not TYPE_CHECKING:
    Config = dynamic_union_from_static_union(Config)


def main(config: Config) -> None:
    # Should resolve to `Bar | Foo`.
    reveal_type(hydra_zen.instantiate(config))


if __name__ == "__main__":
    tyro.cli(main)

Documentation and __init__ support is left as an exercise to the reader. 🙂

For the helptext stuff, I'm guessing you could figure this out yourself, but the custom argparse formatter is probably what you want to look at!

@orperel
Author

orperel commented Feb 21, 2023

@brentyi Thanks for all the useful advice again!

I finally took care of the flat --help mode.
It's somewhat brittle, but here is a draft which currently takes care of it outside of tyro:

import argparse

import tyro


class TyroFlatSubcommandHelpFormatter(tyro._argparse_formatter.TyroArgparseHelpFormatter):
    def add_usage(self, usage, actions, groups, prefix=None):
        aggregated_subcommand_group = []
        for action_name, sub_parser in self.collect_subcommands_parsers(actions).items():
            for sub_action_group in sub_parser._action_groups:
                sub_group_actions = sub_action_group._group_actions
                if len(sub_group_actions) > 0:
                    is_subparser_action = lambda x: isinstance(x, argparse._SubParsersAction)
                    is_help_action = lambda x: isinstance(x, argparse._HelpAction)
                    if any([is_subparser_action(a) and not is_help_action(a) for a in sub_group_actions]):
                        aggregated_subcommand_group.append(sub_action_group)

        # Remove duplicate subcommand parsers
        aggregated_subcommand_group = list({a._group_actions[0].metavar: a
                                            for a in aggregated_subcommand_group}.values())
        next_actions = [g._group_actions[0] for g in aggregated_subcommand_group]
        actions.extend(next_actions)
        super().add_usage(usage, actions, groups, prefix)

    def add_arguments(self, action_group):
        if len(action_group) > 0 and action_group[0].container.title == 'subcommands':
            # If a subcommands action group - rename first subcommand (for which this function was invoked)
            choices_header = next(iter(action_group[0].choices))
            choices_title = choices_header.split(':')[0] + ' choices'
            action_group[0].container.title = choices_title
            self._current_section.heading = choices_title  # Formatter have already set a section, override heading

        # Invoke default
        super().add_arguments(action_group)

        aggregated_action_group = []
        aggregated_subcommand_group = []
        for action in action_group:
            if not isinstance(action, argparse._SubParsersAction):
                continue
            for action_name, sub_parser in self.collect_subcommands_parsers([action]).items():
                sub_parser.formatter_class = type(self)  # formatter_class expects a class, not an instance
                for sub_action_group in sub_parser._action_groups:
                    sub_group_actions = sub_action_group._group_actions
                    if len(sub_group_actions) > 0:
                        is_subparser_action = lambda x: isinstance(x, argparse._SubParsersAction)
                        is_help_action = lambda x: isinstance(x, argparse._HelpAction)
                        if any([not is_subparser_action(a) and not is_help_action(a) for a in sub_group_actions]):
                            for a in sub_group_actions:
                                a.container.title = action_name + ' arguments'
                            aggregated_action_group.append(sub_action_group)
                        elif any([not is_help_action(a) for a in sub_group_actions]):
                            for a in sub_group_actions:
                                choices_header = next(iter(sub_group_actions[0].choices))
                                a.container.title = choices_header.split(':')[0] + ' choices'
                            aggregated_subcommand_group.append(sub_action_group)

        # Remove duplicate subcommand parsers
        aggregated_subcommand_group = list({a._group_actions[0].metavar: a
                                            for a in aggregated_subcommand_group}.values())
        for aggregated_group in (aggregated_subcommand_group, aggregated_action_group):
            for next_action_group in aggregated_group:
                self.end_section()
                self.start_section(next_action_group.title)
                self.add_text(next_action_group.description)
                super().add_arguments(next_action_group._group_actions)

    def collect_subcommands_parsers(self, actions):
        collected_titles = list()
        collected_subparsers = list()
        parsers = list()

        def _handle_actions(_actions):
            action_choices = [action.choices for action in _actions if isinstance(action, argparse._SubParsersAction)]
            for choices in action_choices:
                for subcommand, subcommand_parser in choices.items():
                    collected_titles.append(subcommand)
                    collected_subparsers.append(subcommand_parser)
                    parsers.append(subcommand_parser)

        _handle_actions(actions)
        while parsers:
            parser = parsers.pop(0)
            _handle_actions(parser._actions)

        # Eliminate duplicates and preserve order (dicts are guaranteed to preserve insertion order from python >=3.7)
        return dict(zip(collected_titles, collected_subparsers))

I can test it like this:

    parser = tyro.extras.get_parser(Config)
    parser.formatter_class = TyroFlatSubcommandHelpFormatter
    parser.print_help()

but tyro.cli() doesn't accept a custom formatter, so I guess at the very least we'll have to introduce some hook?

@brentyi
Owner

brentyi commented Feb 21, 2023

Cool!

Yeah, I guess the hacky short-term solution is a monkey patch?

tyro._argparse_formatter.TyroArgparseHelpFormatter = TyroFlatSubcommandHelpFormatter
tyro.cli(...)

Accepting + supporting custom formatters seems like a can of worms that I'm not sure we want to open...!
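
If the monkey patch route is taken, scoping it keeps the override from leaking into other code. Below is a minimal sketch of a generic attribute swap that restores the original afterwards; swapped_attr is a hypothetical helper, and a SimpleNamespace stands in for the patched module so the snippet runs without tyro. It only helps when the attribute is looked up on the module at call time.

```python
import argparse
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any, Iterator


@contextmanager
def swapped_attr(owner: Any, name: str, replacement: Any) -> Iterator[None]:
    """Temporarily replace `owner.name`, restoring the original on exit."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)


# Stand-in for a module that exposes a formatter class as an attribute.
fake_module = SimpleNamespace(HelpFormatter=argparse.HelpFormatter)

with swapped_attr(fake_module, "HelpFormatter", argparse.RawTextHelpFormatter):
    inside = fake_module.HelpFormatter  # The override is visible here.

after = fake_module.HelpFormatter  # The original is back in place.
```

The standard library's unittest.mock.patch.object offers the same behavior as a context manager.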

@orperel
Author

orperel commented Feb 21, 2023

That makes sense!
Could we come up with a more future-proof solution on tyro's side? Something like:

If --help and a "FlatSubparsers" tyro marker is passed, execute this extra logic:

  1. Invoke collect_subcommands_parsers() from the snippet above
  2. Iterate similar to add_arguments() above to collect unique argument groups, carefully avoiding dups
  3. Add all subparser groups and then store action groups to the parser
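
The steps above can be sketched on a plain argparse parser, independent of tyro. This is only an outline of the collection and de-duplication logic, not a proposed implementation; the function names and the example parser are made up for illustration, and the sketch relies on private argparse attributes (_actions, _SubParsersAction), which is part of why the approach is brittle.

```python
import argparse
from typing import Dict, List


def collect_subcommand_parsers(
    parser: argparse.ArgumentParser,
) -> Dict[str, argparse.ArgumentParser]:
    """Breadth-first walk over nested subparsers, keeping the first parser per name."""
    collected: Dict[str, argparse.ArgumentParser] = {}
    queue = [parser]
    while queue:
        current = queue.pop(0)
        for action in current._actions:
            if not isinstance(action, argparse._SubParsersAction):
                continue
            for name, sub_parser in action.choices.items():
                if name not in collected:  # Avoid duplicates.
                    collected[name] = sub_parser
                    queue.append(sub_parser)
    return collected


def flat_option_groups(parser: argparse.ArgumentParser) -> Dict[str, List[str]]:
    """Group every subcommand's optional flags under an '<name> arguments' title."""
    groups: Dict[str, List[str]] = {}
    for name, sub_parser in collect_subcommand_parsers(parser).items():
        groups[name + " arguments"] = [
            action.option_strings[0]
            for action in sub_parser._actions
            if action.option_strings and not isinstance(action, argparse._HelpAction)
        ]
    return groups


# Small example tree: two top-level choices, one of which nests another choice.
root = argparse.ArgumentParser(prog="example.py")
top = root.add_subparsers(dest="foobar")
top.add_parser("foobar:foo").add_argument("--arg1", type=int)
bar = top.add_parser("foobar:bar")
bar.add_argument("--arg5", type=int)
bar.add_subparsers(dest="set2").add_parser("set2:a").add_argument("--x")

groups = flat_option_groups(root)
```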

@brentyi
Owner

brentyi commented Feb 22, 2023

That makes sense!

A marker makes sense, but unless documented otherwise it would imply some level of fine-grained control: for example, when we have a deeply nested subcommand tree and apply the annotation at an intermediate level. Is this possible to implement?

For aesthetic things like this, global state might also be okay; for example, we currently have tyro.extras.set_accent_color():

https://brentyi.github.io/tyro/api/tyro/extras/#tyro.extras.set_accent_color

We could broaden this a bit into something like tyro.extras.configure_helptext() that could give us more fine-grained control over colors, subcommand flattening, etc?
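
To make the global-state idea concrete, here is a purely hypothetical sketch of what such a configuration hook could look like. Nothing below exists in tyro: configure_helptext, HelptextSettings and the flatten_subcommands option are invented for illustration.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class HelptextSettings:
    accent_color: Optional[str] = None
    flatten_subcommands: bool = False


# Module-level state, read by the help formatter when it renders.
_settings = HelptextSettings()


def configure_helptext(**overrides: Any) -> HelptextSettings:
    """Update the global helptext settings; unknown option names raise TypeError."""
    global _settings
    _settings = replace(_settings, **overrides)
    return _settings


def get_helptext_settings() -> HelptextSettings:
    return _settings


configure_helptext(flatten_subcommands=True)
current = get_helptext_settings()
```

Because the setting is global, it sidesteps the question of what a marker applied at an intermediate level of a nested subcommand tree should mean.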
