hydra-zen + tyro ❤️ #23

Open
rsokl opened this issue Dec 8, 2022 · 13 comments

Comments

@rsokl

rsokl commented Dec 8, 2022

Hello! I just came across tyro and it looks great!

I wanted to put hydra-zen on your radar. It is a library designed to make Hydra-based projects more Pythonic and lighter on boilerplate code. It mainly does this by providing users with functions like builds and just, which dynamically generate dataclasses that describe how to build, via instantiate, various objects. There are plenty of bells and whistles that I could go into (e.g. nice support for partial'd targets), but I'll keep it brief-ish.

That being said, hydra-zen's main features are quite independent of Hydra, and are more focused on generating dataclasses that can configure/build various Python interfaces. It seems like this might be the sort of thing that could be helpful for tyro users who want to generate nested, typed interfaces based on objects in their library or from third party libraries.

This is just a rough idea at this point, but I figured that there might be some potential synergy here! I'd love to get your impressions if you think there might be any value here.
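As a rough mental model of "dynamically generating a dataclass that describes how to build an object", here is a stdlib-only toy, loosely analogous to builds(..., populate_full_signature=True) followed by instantiate. The names toy_builds and toy_instantiate are hypothetical and invented for this sketch; hydra-zen's real functions do much more (import-path targets, partials, validation) and are not implemented this way.

```python
import dataclasses
import inspect
from typing import Any, Callable


def toy_builds(target: Callable[..., Any]) -> type:
    """Generate a dataclass whose fields mirror `target`'s signature (toy sketch)."""
    fields = []
    for name, param in inspect.signature(target).parameters.items():
        # *args/**kwargs are not handled in this toy.
        annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
        if param.default is inspect.Parameter.empty:
            fields.append((name, annotation))
        else:
            fields.append((name, annotation, dataclasses.field(default=param.default)))
    cls = dataclasses.make_dataclass(f"Builds_{target.__name__}", fields)
    cls._toy_target = target  # remember what to call at "instantiate" time
    return cls


def toy_instantiate(config: Any) -> Any:
    """Call the remembered target with the config's field values."""
    kwargs = {f.name: getattr(config, f.name) for f in dataclasses.fields(config)}
    return type(config)._toy_target(**kwargs)


class MyModel:
    def __init__(self, layers: int, units: int = 32) -> None:
        self.layers = layers
        self.units = units


ModelConf = toy_builds(MyModel)
model = toy_instantiate(ModelConf(layers=3))
print(type(model).__name__, model.layers, model.units)  # MyModel 3 32
```

The config class is plain data, so any tool that understands dataclasses can inspect or populate it before anything is built.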

@brentyi
Owner

brentyi commented Dec 9, 2022

Thanks for the note Ryan!

I spent some time today experimenting with hydra-zen and it looks super useful. When populate_full_signature is set to True, defining a config dataclass with builds(), passing it to tyro.cli(), and then instantiating it with instantiate() pretty much also works out of the box. Any specific features you can think of for stress testing the tyro <=> hydra-zen interaction?
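The "works out of the box" flow relies on the CLI library turning a config dataclass's fields into flags. The toy below approximates that one step with argparse; dataclass_cli is a hypothetical helper, not tyro's implementation, and it only handles flat fields with simple callable types. Parsing from an explicit args list rather than sys.argv keeps it easy to test.

```python
import argparse
import dataclasses
from typing import List, Optional, Type, TypeVar

T = TypeVar("T")


def dataclass_cli(cls: Type[T], args: Optional[List[str]] = None) -> T:
    """Create one --flag per dataclass field, parse `args`, and return an instance."""
    parser = argparse.ArgumentParser()
    for f in dataclasses.fields(cls):
        has_default = f.default is not dataclasses.MISSING
        parser.add_argument(
            f"--{f.name}",
            type=f.type,  # assumes a simple callable type such as int or str
            required=not has_default,
            default=f.default if has_default else None,
        )
    namespace = parser.parse_args(args)  # args=None would fall back to sys.argv
    return cls(**vars(namespace))


@dataclasses.dataclass
class ModelConfig:
    layers: int
    units: int = 32


config = dataclass_cli(ModelConfig, args=["--layers", "3"])
print(config)  # ModelConfig(layers=3, units=32)
```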

My main thought is that it'd be nice to have an equivalent of builds(T) that I can use as a static annotation, e.g.:

@dataclasses.dataclass
class Config:
    dataloader: Builds[SomeDataloaderClass]
    model: Builds[SomeModel]

tyro.cli(Config)
  • Is there already a way to do something similar? (aside from ConfigStore?)
  • Defining something like BuildsSomeDataloaderClass: TypeAlias = builds(SomeDataloaderClass) will probably actually work for this in its current state and even type check in pyright, but it'd be nice to have something less sketchy.

@rsokl
Author

rsokl commented Dec 9, 2022

Hi Brent, I am so glad that you find this to be useful!

My main thought is that it'd be nice to have an equivalent of builds(T) that I can use as a static annotation

If I am understanding this correctly, hydra_zen.typing solves this already 😄. It provides a protocol Builds that does what you are looking for¹:

from dataclasses import dataclass
from typing import TypeVar, Type

from typing_extensions import assert_type

from hydra_zen import instantiate
from hydra_zen.typing import Builds

T = TypeVar("T")

class MyModel: ...

@dataclass
class Config:
    model: Builds[Type[MyModel]]


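def check(config: Config) -> None:
    # `model` has no default, so read it from an instance rather than from the class.
    out = instantiate(config.model)
    assert_type(out, MyModel)  # passes for both mypy and pyright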

Does this address what you are looking for?

Any specific features you can think of for stress testing the tyro <=> hydra-zen interaction?

Fortunately, hydra-zen isn't doing anything too magical under the hood beyond generating dataclass types, so I don't suspect there are too many chances for incompatibilities. I will try to find some time later to come up with a list of potential edge cases to test.

I think the main things to think about are opportunities to boost the ergonomics of tyro + hydra-zen. I would like to spend time playing with tyro + hydra-zen in more realistic settings to see if there are any places where hydra-zen's auto-config capabilities could be bootstrapped by tyro to streamline the user experience.

One thing that might be handy is the zen wrapper that will be included in the upcoming release. This will auto-extract & instantiate fields from a config and pass them to a function based on a function's signature. This might save users from having to manually call instantiate on objects.
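To make the idea concrete, here is a hypothetical stdlib-only toy_zen decorator: it reads the wrapped function's signature, pulls the same-named fields out of a config mapping, instantiates nested configs, and calls the function. The "_target_" key is borrowed from Hydra's convention, but here it holds a callable rather than an import path. This is an assumption about the general shape of such a wrapper, not hydra-zen's zen API.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict


def toy_instantiate(value: Any) -> Any:
    """Recursively build dicts that carry a callable under the '_target_' key."""
    if isinstance(value, dict) and callable(value.get("_target_")):
        kwargs = {k: toy_instantiate(v) for k, v in value.items() if k != "_target_"}
        return value["_target_"](**kwargs)
    return value


def toy_zen(func: Callable[..., Any]) -> Callable[[Dict[str, Any]], Any]:
    """Wrap `func` so that it can be called with an entire config mapping."""
    param_names = list(inspect.signature(func).parameters)

    def wrapper(config: Dict[str, Any]) -> Any:
        # Extract only the fields named in func's signature, instantiating each one.
        kwargs = {name: toy_instantiate(config[name]) for name in param_names if name in config}
        return func(**kwargs)

    return wrapper


@dataclass
class Model:
    layers: int


def train(model: Model, epochs: int) -> str:
    return f"training {model} for {epochs} epochs"


config = {
    "model": {"_target_": Model, "layers": 3},
    "epochs": 2,
    "unused": "ignored, since train() has no such parameter",
}
print(toy_zen(train)(config))  # training Model(layers=3) for 2 epochs
```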

These are just some initial thoughts. I am happy to discuss/brainstorm further!

Footnotes

  1. The need for Type here is a little unfortunate (but necessary for internal consistency). That being said, hydra_zen can provide a type alias ShortBuilds: TypeAlias = Builds[Callable[..., T]], which I could add to hydra_zen – this would eliminate the need for Type.

    Defining something like BuildsSomeDataloaderClass: TypeAlias = builds(SomeDataloaderClass) will probably actually work for this in its current state and even type check in pyright, but it'd be nice to have something less sketchy.

    Unfortunately, this would not work, as builds(...) is not a valid type expression. mypy yells about this. pyright had been yelling about this, but weirdly isn't anymore... which seems like a regression.
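The footnote's point about Type can be illustrated with a self-contained toy. ToyBuilds is a hypothetical generic stand-in, not hydra-zen's Builds protocol: like Builds, it is parameterized by the callable that builds the object, so a class target appears as Type[MyModel]. A ToyShortBuilds alias over Callable[..., T] lets annotations name the result type instead. A type checker should infer MyModel as the return type of toy_instantiate in both forms; at runtime the sketch simply calls the target.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Type, TypeVar

# Covariant so that ToyBuilds[Type[X]] is accepted where ToyBuilds[Callable[..., X]] is expected.
F = TypeVar("F", bound=Callable[..., Any], covariant=True)
R = TypeVar("R")
T = TypeVar("T")


@dataclass(frozen=True)
class ToyBuilds(Generic[F]):
    """Hypothetical stand-in: a target callable plus the kwargs to call it with."""

    target: F
    kwargs: Dict[str, Any] = field(default_factory=dict)


# Alias parameterized by the *result* type, so users need not write Type[...].
ToyShortBuilds = ToyBuilds[Callable[..., T]]


def toy_instantiate(cfg: ToyBuilds[Callable[..., R]]) -> R:
    return cfg.target(**cfg.kwargs)


class MyModel:
    pass


long_form: ToyBuilds[Type[MyModel]] = ToyBuilds(MyModel)
short_form: ToyShortBuilds[MyModel] = ToyBuilds(MyModel)
print(type(toy_instantiate(long_form)).__name__)  # MyModel
```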

@brentyi
Owner

brentyi commented Dec 9, 2022

Thanks for the example! I had seen Builds[] but didn't realize it was exposed publicly; maybe the next step is to try and figure out the best way to help tyro understand it.

Here's a sketchy prototype of the behavior that would be nice:

from dataclasses import dataclass, field
from typing import Type, get_args, get_origin, TYPE_CHECKING

from hydra_zen import builds, instantiate
from hydra_zen.typing import Builds
from typing_extensions import assert_type

import tyro


class MyModel:
    def __init__(self, layers: int, units: int = 32) -> None:
        """Initialize model.

        Args:
            layers: Number of layers.
            units: Number of units.
        """
        pass


if not TYPE_CHECKING:
    # Make Builds[Type[T]] annotations evaluate to builds(T) at runtime.

    def monkey(t):
        if get_origin(t) is type:
            inner_type = get_args(t)[0]
            out = builds(inner_type, populate_full_signature=True)
            out.__doc__ = inner_type.__init__.__doc__  # This will overwrite the current hydra-zen docstring.
            return out
        else:
            return Builds

    Builds.__class_getitem__ = monkey


@dataclass
class Config:
    model: Builds[Type[MyModel]] = field(
        # builds() currently returns a mutable / non-frozen dataclass.
        default_factory=lambda: builds(
            MyModel,
            populate_full_signature=True,
        )(layers=3)
    )


config = tyro.cli(Config)
out = instantiate(config.model)
assert_type(out, MyModel)


My guess is the options are:

  • Add a __class_getitem__ patch like this to hydra-zen.
  • Add some hydra-zen-specific support in tyro.
  • Expose some interface for registering custom types / protocols in tyro?

@brentyi
Owner

brentyi commented Dec 9, 2022

Also, I imagine you found tyro via nerfstudio? We actually did discuss a pattern for directly mapping args -> config schemas similar to builds(), and there's a very rough nonfunctional prototype of an API in a branch. It's more or less a worse version of builds() in hydra-zen.

I ended up having some hesitations about it and the team went with explicit ModelConfig dataclasses for each Model type; concerns here would more or less apply to hydra-zen as well, so if you have thoughts I'd be curious.

The main thing was losing some specific advantages of the explicit approach, where typically the Model would take the config object as an input:

  • Not that we particularly like inheritance, but the external dataclass config makes inheritance much less painful. The alternative, explicitly passing in arguments, means that if a parent class has N arguments and a subclass wants N+1 arguments, either a non-type-safe **parent_kwargs has to be used or those N arguments in the parent's __init__() need to be repeated in the subclass's __init__().
  • When config parameters are only needed in methods like forward(): with an explicit config dataclass, only self.config = config needs to be written in the constructor regardless of the quantity of config parameters. This was nice for eliminating the usual self.num_units = num_units, self.layers = layers, ... boilerplate you see in __init__() implementations.

And then lastly just a desire to reduce magic, since we have a lot of contributors from research backgrounds who were new to types + Python and there were hesitations about learning barriers introduced by things like protocols or dynamically generated classes.
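A minimal, framework-free sketch of the explicit-config pattern described in the bullets above; ModelConfig, Model, and the Wide* subclasses are hypothetical names, and forward() is a stand-in for real computation. The subclass config inherits the parent's fields and only writes out the new one, and each constructor just stores self.config.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    layers: int = 2
    units: int = 32


@dataclass
class WideModelConfig(ModelConfig):
    # Inherits `layers` and `units`; only the new field is written out.
    width_multiplier: float = 2.0


class Model:
    def __init__(self, config: ModelConfig) -> None:
        self.config = config  # one line, regardless of how many fields the config has

    def forward(self, x: float) -> float:
        # Config values are read where they are needed, e.g. in forward().
        return x * self.config.layers * self.config.units


class WideModel(Model):
    def __init__(self, config: WideModelConfig) -> None:
        super().__init__(config)
        self.config: WideModelConfig = config  # narrows the stored type for type checkers

    def forward(self, x: float) -> float:
        return super().forward(x) * self.config.width_multiplier


model = WideModel(WideModelConfig(layers=3))
print(model.forward(1.0))  # 192.0
```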

@rsokl
Copy link
Author

rsokl commented Dec 16, 2022

(sorry for the delay! Will get back to you asap)

@rsokl
Author

rsokl commented Dec 19, 2022

Here's a sketchy prototype of the behavior that would be nice:

Ah, I see now how tyro would leverage this sort of enhanced Builds type. That's nice! I think it might be best for tyro to expose its own tyro.hydra_zen.Builds & friends protocols that look something like:

# contents of src/tyro/hydra_zen.py
from typing import TYPE_CHECKING, TypeVar, get_args

from typing_extensions import Protocol

from hydra_zen import builds
from hydra_zen.typing import Builds as _Builds
from hydra_zen.typing import PartialBuilds as _PBuilds

T = TypeVar("T", covariant=True)

class Builds(_Builds[T], Protocol[T]):
    if not TYPE_CHECKING:
        def __class_getitem__(cls, key):
            inner_type = get_args(key)[0]
            out = builds(inner_type, populate_full_signature=True)
            out.__doc__ = inner_type.__init__.__doc__  # This will overwrite the current hydra-zen docstring.
            return out

class PartialBuilds(_PBuilds[T], Protocol[T]):
    if not TYPE_CHECKING:
        def __class_getitem__(cls, key):
            inner_type = get_args(key)[0]
            out = builds(inner_type, populate_full_signature=True, zen_partial=True)
            out.__doc__ = inner_type.__init__.__doc__ 
            return out

(things aren't quite right with these protocols and their relationships with their parents, but there is definitely a solution that will work, I just need to tinker around a little more).
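For reference, Python never calls a method named __get_item__; the hook for subscripting a class (Cls[key]) is __class_getitem__, which is treated as a classmethod automatically. The stdlib-only toy below shows the runtime behavior these protocols aim for. ToyBuilds is hypothetical, and it generates a plain dataclass from the constructor signature instead of calling hydra-zen's builds.

```python
import dataclasses
import inspect
from typing import Any, Type, get_args, get_origin


class ToyBuilds:
    """Hypothetical marker: ToyBuilds[Type[X]] evaluates to a dataclass mirroring X's __init__."""

    def __class_getitem__(cls, key: Any) -> type:
        # Expect Type[X] and unwrap it to X.
        if get_origin(key) is not type:
            raise TypeError(f"expected Type[...], got {key!r}")
        (target,) = get_args(key)
        fields = []
        for name, param in inspect.signature(target).parameters.items():
            if param.default is inspect.Parameter.empty:
                fields.append((name, param.annotation))
            else:
                fields.append((name, param.annotation, dataclasses.field(default=param.default)))
        return dataclasses.make_dataclass(f"Builds_{target.__name__}", fields)


class MyModel:
    def __init__(self, layers: int, units: int = 32) -> None:
        self.layers, self.units = layers, units


ModelConf = ToyBuilds[Type[MyModel]]
print(ModelConf(layers=3))  # Builds_MyModel(layers=3, units=32)
```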

This way tyro users can opt in to using this feature and you can control the behavior of Builds et al. so that it suits tyro. I would recommend against monkey patching Builds itself so that you don't accidentally impact hydra-zen code elsewhere (including runtime performance). I would be happy to help with this and to write tests for / solutions to edge cases (e.g., you can't do builds(dict, populate_full_signature=True) because dict doesn't possess an inspectable signature).

# builds() currently returns a mutable / non-frozen dataclass.

Just FYI: you can have builds produce a frozen dataclass via the frozen option (which in the next release will be deprecated in favor of builds(..., zen_dataclass={'frozen': True}) -- all make_dataclass/@dataclass options will be available through this argument.)
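The mutability point is why the earlier prototype needed default_factory: a frozen dataclass instance is hashable, so it can be used directly as a field default, whereas Python 3.11+ rejects unhashable defaults such as instances of a regular eq=True dataclass. A stdlib-only illustration (the class names are made up for this example):

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class FrozenModelConf:
    layers: int = 3


@dataclass
class Config:
    # A frozen instance is hashable, so it is accepted as a plain default.
    # Every Config shares this one instance, which is safe because it cannot be mutated.
    model: FrozenModelConf = FrozenModelConf(layers=3)


config = Config()
try:
    config.model.layers = 5  # frozen: attribute assignment is rejected
except dataclasses.FrozenInstanceError as err:
    print("rejected:", type(err).__name__)  # rejected: FrozenInstanceError
```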

I ended up having some hesitations about it and the team went with explicit ModelConfig dataclasses for each Model type; concerns here would more or less apply to hydra-zen as well, so if you have thoughts I'd be curious.

Sure!

  • Not that we particularly like inheritance, but the external dataclass config makes inheritance much less painful.
  • When config parameters are only needed in methods like forward(): with an explicit config dataclass, only self.config = config needs to be written in the constructor

It seems like both of these things could be achieved -- with full "explicitness" -- by having each Model type, itself, be a dataclass. This would enable the model itself to have an explicit signature, without the need for boilerplate in the case of inheritance or a trivial constructor. That being said, it isn't trivial to make nn.Module behave like a dataclass (see this issue) (edit: actually, maybe I found a solution! See my later posts). If PyTorch made it easy to support dataclass-based nn.Modules I would 100% recommend using those to keep the model types explicit, and then using something akin to builds to generate corresponding configs for them.

It seems like you are effectively circumventing this by decoupling your inits from your classes via your dataclass-based configs. I think it is a reasonable approach. You would indeed incur a whole lot of boilerplate code if you tried to keep everything explicit.

hydra-zen is all about encouraging users to be explicit in their library interfaces, to avoid repetition, and to follow the dictum that "frameworks should be kept at arm's length"¹ (e.g., configuration frameworks like Hydra). Projects can typically heed this without issue because they don't have too much inheritance going on, or they aren't providing a bunch of nn.Module types. That being said, if I were working on nerfstudio, I don't think I would try to persuade anyone to change things at this point, unless you find that something like a dataclass+nn.Module solution is viable.

I hope that this was helpful!

Footnotes

  1. The maintainer of Hydra shared this quote with me when he was sharing his perspective on hydra-zen and the value that it adds. This really resonates with me.

@rsokl
Author

rsokl commented Dec 19, 2022

btw I played with @dataclass + nn.Module some and this does work:

import torch as tr
from dataclasses import dataclass

@dataclass
class Module(tr.nn.Module):
    def __post_init__(self):
        super().__init__()

@dataclass(unsafe_hash=True)
class A(Module):
    x: int
    def __post_init__(self):
        super().__post_init__()
        self.layer1 = tr.nn.Linear(self.x, self.x)

@dataclass(unsafe_hash=True)
class B(A):
    y: int
    def __post_init__(self):
        super().__post_init__()
        self.layer2 = tr.nn.Linear(self.y, self.y)
>>> list(A(1).parameters())
[Parameter containing:
 tensor([[-0.0720]], requires_grad=True),
 Parameter containing:
 tensor([0.3248], requires_grad=True)]
>>> list(B(1, 2).parameters())
[Parameter containing:
 tensor([[0.1868]], requires_grad=True),
 Parameter containing:
 tensor([0.9107], requires_grad=True),
 Parameter containing:
 tensor([[-0.6322,  0.4550],
         [-0.6481,  0.2250]], requires_grad=True),
 Parameter containing:
 tensor([-0.2255,  0.6488], requires_grad=True)]

but there are caveats... this doesn't work if any of your init fields are nn.Modules themselves:

import torch.nn as nn
from dataclasses import dataclass

class Evaluators(nn.Module):
    def __init__(self):
        super(Evaluators, self).__init__()
        self.linear = nn.Linear(1, 1)

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    evaluator: Evaluators
    def __post_init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

evaluators = Evaluators()
net = Net(evaluators )
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
c:\Users\rsokl\hydra-zen\scratch\scratch.ipynb Cell 3 in <cell line: 19>()
     16         self.linear = nn.Linear(1, 1)
     18 evaluators = Evaluators()
---> 19 net = Net(evaluators )

File <string>:3, in __init__(self, evaluator)

File c:\Users\rsokl\miniconda3\envs\rai\lib\site-packages\torch\nn\modules\module.py:1236, in Module.__setattr__(self, name, value)
   1234 if isinstance(value, Module):
   1235     if modules is None:
-> 1236         raise AttributeError(
   1237             "cannot assign module before Module.__init__() call")
   1238     remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
   1239     modules[name] = value

AttributeError: cannot assign module before Module.__init__() call

@brentyi
Owner

brentyi commented Dec 19, 2022

Thanks for taking the time!

Need some time to think about the rest, but PyTorch3D has some code that attempts something similar: https://pytorch3d.org/tutorials/implicitron_config_system

@rsokl
Author

rsokl commented Dec 19, 2022

Actually... I think there is a solution to the nn.Module + dataclass issue! Just use __new__ as a __pre_init__:

import torch as tr
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class DataclassModule(nn.Module):
    def __new__(cls, *args, **k):
        inst = super().__new__(cls)
        nn.Module.__init__(inst)
        return inst

@dataclass(unsafe_hash=True)
class Net(DataclassModule):
    other_layer: nn.Module
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        self.layer = nn.Linear(self.input_feats, self.output_feats)

    def forward(self, x):
        return self.layer(self.other_layer(x))

net = Net(other_layer=nn.Linear(10, 10))
assert net(tr.tensor([1.]*10)).shape == (20,)
assert len(list(net.parameters())) == 4

@dataclass(unsafe_hash=True)
class A(DataclassModule):
    x: int
    def __post_init__(self):
        self.layer1 = nn.Linear(self.x, self.x)

@dataclass(unsafe_hash=True)
class B(A):
    y: int
    def __post_init__(self):
        super().__post_init__()
        self.layer2 = nn.Linear(self.y, self.y)

assert len(list(A(1).parameters())) == 2
assert len(list(B(1, 2).parameters())) == 4
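Why the __new__ trick works can be reproduced without PyTorch. StrictBase below is a hypothetical stand-in that, like nn.Module, refuses attribute assignment until its own __init__ has run. The dataclass-generated __init__ never calls the base __init__, but __new__ runs first, so the base is initialized before any field is assigned. This sketches the mechanism only, not nn.Module's internals.

```python
from dataclasses import dataclass


class StrictBase:
    """Hypothetical stand-in: attribute assignment requires __init__ to have run."""

    def __init__(self) -> None:
        object.__setattr__(self, "_ready", True)

    def __setattr__(self, name: str, value: object) -> None:
        if not self.__dict__.get("_ready", False):
            raise AttributeError(f"cannot assign {name!r} before StrictBase.__init__() call")
        object.__setattr__(self, name, value)


@dataclass
class NaiveChild(StrictBase):
    x: int  # the generated __init__ assigns self.x without calling StrictBase.__init__


class PreInitBase(StrictBase):
    def __new__(cls, *args, **kwargs):
        inst = super().__new__(cls)  # object.__new__ must not receive the field arguments
        StrictBase.__init__(inst)  # "pre-init": runs before the dataclass __init__
        return inst


@dataclass
class FixedChild(PreInitBase):
    x: int


try:
    NaiveChild(1)
except AttributeError as err:
    print("naive:", err)

print("fixed:", FixedChild(1).x)  # fixed: 1
```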

@brentyi
Owner

brentyi commented Dec 20, 2022

Thanks again for the detailed responses! Helpful + thought-provoking.

On supporting Builds via a tyro.hydra_zen.* module: this makes sense, but I'm hesitant to add hydra_zen-specific things to tyro's user-facing API. Seems like a slightly slippery slope.

I'm leaning toward adding some more general infrastructure in tyro for adding rules for handling hydra_zen.typing.Builds[T] and hydra_zen.typing.PartialBuilds[T] as special cases, some way to tell tyro what to do when it sees these in parsed type annotations.

Probably this could be broken into a function that checks whether an annotation matches a rule and another function that defines how to instantiate these annotations:

from typing import Any, Callable, Type, TypeVar, get_args, get_origin

import hydra_zen
import hydra_zen.typing
import tyro

# Type[Any] is not quite the correct annotation (Builds[T] is not a real type) but is how `tyro` currently annotates things internally.

def matcher(typ: Type[Any]) -> bool:
    """Returns true when `typ` is a Builds[T] type."""
    return get_origin(typ) is hydra_zen.typing.Builds

T = TypeVar("T")

def instantiator(typ: Type[T]) -> Callable[..., T]:
    """Takes a Builds[T] protocol type as input, and returns a handler for instantiating the type."""
    assert get_origin(typ) is hydra_zen.typing.Builds
    (inner,) = get_args(typ)
    return hydra_zen.builds(inner, populate_full_signature=True)

# Proposed registration API.
tyro._registry.register_custom_instantiator(
    matcher,
    instantiator,
)

Does that make sense to you? The API specifics could probably use more thought — potentially it could be used to refactor all of the special handling for things like dataclasses, pydantic, attrs, TypedDict, etc that's currently hardcoded in tyro — and the implementation details could use your input but I think the core idea should work.
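A stdlib-only toy of the matcher/instantiator idea, to show the dispatch logic in isolation. Everything here (register_custom_instantiator, resolve, the BuildsMarker generic) is hypothetical and standalone, not tyro's API. Rules are tried in registration order and the first matcher that returns True wins; otherwise the annotation is used as-is.

```python
from typing import Any, Callable, Generic, List, Tuple, TypeVar, get_args, get_origin

T = TypeVar("T")
Matcher = Callable[[Any], bool]
Instantiator = Callable[[Any], Callable[..., Any]]

_RULES: List[Tuple[Matcher, Instantiator]] = []


def register_custom_instantiator(matcher: Matcher, instantiator: Instantiator) -> None:
    _RULES.append((matcher, instantiator))


def resolve(typ: Any) -> Callable[..., Any]:
    """Return a constructor for `typ`, using the first matching custom rule."""
    for matcher, instantiator in _RULES:
        if matcher(typ):
            return instantiator(typ)
    return typ  # fall back to treating the annotation as a plain class


class BuildsMarker(Generic[T]):
    """Hypothetical stand-in for a Builds[T]-style annotation."""


def is_builds(typ: Any) -> bool:
    return get_origin(typ) is BuildsMarker


def builds_instantiator(typ: Any) -> Callable[..., Any]:
    (inner,) = get_args(typ)
    return inner  # a real rule would return a generated config class instead


register_custom_instantiator(is_builds, builds_instantiator)

print(resolve(BuildsMarker[int])("7"))  # 7
print(resolve(float)("1.5"))  # 1.5
```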

On the PyTorch module + dataclass notes: I totally agree with everything you wrote; I've been using flax for all my personal things, which has a dataclass-based API for defining modules that I miss dearly when switching to PyTorch. The PyTorch module + dataclass implementations are also really cool, although it's a bit scary to be hacking away from super established patterns here. :)

@rsokl
Author

rsokl commented Dec 20, 2022

I'm leaning toward adding some more general infrastructure in tyro

Nice! I like the idea of register_custom_instantiator. This is way nicer than hacking Builds. What is really nice about this is that tyro can expose an entrypoint so that 3rd parties can automatically register custom instantiators (users can also manually call register_custom_instantiator). I.e., hydra-zen can be completely responsible for shipping the instantiators for Builds et al. and exposing a hook for tyro's entrypoint to call. This way, if someone installs both hydra-zen and tyro, they will see that an instantiator for Builds is automatically registered upon importing tyro.

I like this a lot because tyro won't have to ship any hydra-zen specific logic whatsoever! This obviously is ideal because this enables me to make changes to protocols and improvements to the registered instantiators without having to pester you about updating tyro. Naturally, a tyro user should only have to update their hydra-zen version for the latest and greatest support for Builds et al. 😄 .

If you are interested in what this would look like on tyro's end, Hypothesis exposes an entrypoint for its register_strategy function; here is its implementation and here is where it gets called (simply upon importing hypothesis).

On hydra-zen's end, I would add a file like src/hydra_zen/_tyro_hook.py that would look like:

from typing import Any, Callable, Type, TypeVar, get_args, get_origin

import hydra_zen
import hydra_zen.typing


def matcher(typ: Type[Any]) -> bool:
    """Returns true when `typ` is a Builds[T] type."""
    return get_origin(typ) is hydra_zen.typing.Builds

T = TypeVar("T")

def instantiator(typ: Type[T]) -> Callable[..., T]:
    """Takes a Builds[T] protocol type as input, and returns a handler for instantiating the type."""
    assert get_origin(typ) is hydra_zen.typing.Builds
    (inner,) = get_args(typ)
    return hydra_zen.builds(inner, populate_full_signature=True)


def _tyro_setup_hook():
    import tyro
    tyro._registry.register_custom_instantiator(
        matcher,
        instantiator,
    )

and then hydra-zen's pyproject.toml would add something like

[project.entry-points.tyro]
_ = "hydra_zen._tyro_hook:_tyro_setup_hook"

Obviously, any third party could register instantiators in this way, which is awesome!¹

One last thing that I can foresee, is that it would be nice for tyro to make it easy for 3rd parties to test that their entrypoint hooks are working successfully, in an automated way (i.e., without manually checking the output of a CLI). Perhaps there is already a solution that I am not privy to.
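On the loading side, discovering such hooks could look roughly like the stdlib sketch below, in the spirit of the Hypothesis entry point mentioned above. load_hooks is a hypothetical helper and the "tyro" group name is taken from the proposal; neither is an existing tyro feature. Accepting an optional iterable of entry points makes the loader testable without installing a package: the usage line fakes one with importlib.metadata.EntryPoint pointing at a zero-argument stdlib function.

```python
import importlib.metadata
from typing import Iterable, List, Optional


def load_hooks(
    group: str = "tyro",
    entry_points: Optional[Iterable[importlib.metadata.EntryPoint]] = None,
) -> List[str]:
    """Load and call every hook registered under `group`; return the loaded entry point values."""
    if entry_points is None:
        discovered = importlib.metadata.entry_points()
        if hasattr(discovered, "select"):  # Python 3.10+
            entry_points = discovered.select(group=group)
        else:  # Python 3.8/3.9 return a dict keyed by group name
            entry_points = discovered.get(group, [])
    loaded = []
    for ep in entry_points:
        hook = ep.load()  # imports the module and returns the named attribute
        hook()  # the hook performs its own registrations
        loaded.append(ep.value)
    return loaded


# Testing without installing anything: a fake entry point aimed at a zero-argument stdlib function.
fake = importlib.metadata.EntryPoint(name="_", value="time:time", group="tyro")
print(load_hooks(entry_points=[fake]))  # ['time:time']
```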

Footnotes

  1. Although one bad thing about the register-by-matcher paradigm is that you can't tell what a 3rd party's matcher will match against. Someone could register lambda x: True as their matcher, after all. Perhaps the best thing to do here is to make it easy for users, in some verbose mode, to see whose matcher was triggered in their CLI so that they can identify any 3rd party that is causing issues.

@brentyi
Owner

brentyi commented Dec 26, 2022

Great! I've started working on exposing this functionality, but might take a bit because of holidays and such. There are also some design decisions I need to put more time into thinking through related to how flexible the API should be, handling for typing.Annotated and generics, etc.

One last thing that I can foresee, is that it would be nice for tyro to make it easy for 3rd parties to test that their entrypoint hooks are working successfully, in an automated way (i.e., without manually checking the output of a CLI). Perhaps there is already a solution that I am not privy to.

Could this be set up the same way as tyro's existing unit tests, via tyro.cli()'s args= argument? This tells the underlying argparse parser to read from a provided list of strings instead of sys.argv.

Here's what an existing unit test for pydantic compatibility looks like:

import pathlib

import pytest
from pydantic import BaseModel, Field

import tyro


def test_pydantic() -> None:
    class ManyTypesA(BaseModel):
        i: int
        s: str = "hello"
        f: float = Field(default_factory=lambda: 3.0)
        p: pathlib.Path

    # We can directly pass a pydantic model to `tyro.cli()`:
    assert tyro.cli(
        ManyTypesA,
        args=[
            "--i",
            "5",
            "--p",
            "~",
        ],
    ) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~"))

@rsokl
Author

rsokl commented Dec 30, 2022

Here's what an existing unit test for pydantic compatibility looks like:

Yep. That looks great.
