Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoTransitionCriterion #2409

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 20 additions & 1 deletion ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
# ------------------------- Trial logic helpers. -------------------------
@property
def trials_from_node(self) -> Set[int]:
"""Returns a dictionary mapping a GenerationNode to the trials it generated.
"""Returns a set mapping a GenerationNode to the trials it generated.

Returns:
Set[int]: A set containing all the trials indices generated by this node.
Expand All @@ -384,6 +384,19 @@ def trials_from_node(self) -> Set[int]:
trials_from_node.add(trial.index)
return trials_from_node

@property
def node_that_generated_last_gr(self) -> Optional[str]:
"""Returns the name of the node that generated the last generator run.

Returns:
str: The name of the node that generated the last generator run.
"""
return (
self.generation_strategy.last_generator_run._generation_node_name
if self.generation_strategy.last_generator_run
else None
)

def should_transition_to_next_node(
self, raise_data_required_error: bool = True
) -> Tuple[bool, Optional[str]]:
Expand All @@ -409,6 +422,12 @@ def should_transition_to_next_node(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
curr_node_name=self.node_name,
node_that_generated_last_gr=(
self.generation_strategy.last_generator_run._generation_node_name
if self.generation_strategy.last_generator_run is not None
else None
),
)
for tc in transition_blocking
)
Expand Down
113 changes: 113 additions & 0 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGenCriterion,
MaxGenerationParallelism,
MaxTrials,
MinTrials,
Expand Down Expand Up @@ -149,6 +150,14 @@ def setUp(self) -> None:
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
self.single_running_trial_criterion = [
MaxTrials(
threshold=1,
transition_to="gpei",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
)
]
self.sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs=self.step_model_kwargs,
Expand Down Expand Up @@ -1388,6 +1397,110 @@ def test_generation_strategy_eq_no_print(self) -> None:
)
self.assertEqual(gs1, gs2)

def test_node_gs_with_auto_transitions(self) -> None:
"""Test that node-based generation strategies which leverage
AutoTransitionAfterGen criterion correctly transition and create trials.
"""
gs = GenerationStrategy(
nodes=[
# first node should be our exploration node and only grs from this node
# should be on the first trial
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
# node 2,3,4 will be out iteration nodes, and grs from all 3 nodes
# should be used to make the subsequent trials
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei")
],
),
],
)
exp = get_branin_experiment()

self.assertEqual(gs.current_node_name, "sobol")
exp.new_trial(generator_run=gs.gen(exp)).run()
# while here, test the last generator run property on node
self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")

# TODO: @mgarrard modify below test when gen handles multiple nodes
gs_2 = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
transition_criteria=[
MaxTrials(
threshold=2,
transition_to="sobol_3",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
use_all_trials_in_exp=True,
)
],
),
GenerationNode(
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
),
],
)
self.assertEqual(gs_2.current_node_name, "sobol")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "gpei")
gs_2.gen(exp)
self.assertEqual(gs_2.current_node_name, "sobol_2")
gs_2.gen(exp) # noqa
self.assertEqual(gs_2.current_node_name, "gpei_2")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "sobol_3")

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down
52 changes: 49 additions & 3 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.generation_strategy import (
GenerationNode,
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGenCriterion,
MaxGenerationParallelism,
MaxTrials,
MinimumPreferenceOccurances,
Expand All @@ -30,6 +36,15 @@


class TestTransitionCriterion(TestCase):
def setUp(self) -> None:
super().setUp()
self.sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={"init_position": 3},
model_gen_kwargs={"some_gen_kwarg": "some_value"},
)
self.branin_experiment = get_branin_experiment()

def test_minimum_preference_criterion(self) -> None:
"""Tests the minimum preference criterion subcalss of TransitionCriterion."""
criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3)
Expand Down Expand Up @@ -154,7 +169,7 @@ def test_default_step_criterion_setup(self) -> None:

def test_min_trials_is_met(self) -> None:
"""Test that the is_met method in MinTrials works"""
experiment = get_branin_experiment()
experiment = self.branin_experiment
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
Expand Down Expand Up @@ -216,9 +231,33 @@ def test_min_trials_is_met(self) -> None:
trial._status = TrialStatus.EARLY_STOPPED
self.assertTrue(min_criterion.is_met(experiment, gs._steps[0].trials_from_node))

def test_auto_transition(self) -> None:
"""Very simple test to validate AutoTransitionAfterGenCriterion"""
experiment = self.branin_experiment
gs = GenerationStrategy(
name="test",
nodes=[
GenerationNode(
node_name="sobol_1",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
],
),
GenerationNode(
node_name="sobol_2", model_specs=[self.sobol_model_spec]
),
],
)
gs.experiment = experiment
self.assertEqual(gs.current_node_name, "sobol_1")
gs.gen(experiment=experiment)
gs.gen(experiment=experiment)
self.assertEqual(gs.current_node_name, "sobol_2")

def test_max_trials_is_met(self) -> None:
"""Test that the is_met method in MaxTrials works"""
experiment = get_branin_experiment()
experiment = self.branin_experiment
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
Expand Down Expand Up @@ -446,3 +485,10 @@ def test_repr(self) -> None:
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
)
auto_transition = AutoTransitionAfterGenCriterion(
transition_to="GenerationStep_2"
)
self.assertEqual(
str(auto_transition),
"AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2'})",
)
52 changes: 49 additions & 3 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def transition_to(self) -> Optional[str]:

@abstractmethod
def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
"""If the criterion of this TransitionCriterion is met, returns True."""
pass
Expand Down Expand Up @@ -94,6 +98,38 @@ def _unique_id(self) -> str:
return str(self)


class AutoTransitionAfterGenCriterion(TransitionCriterion):
"""A class to designate automatic transition from one GenerationNode to another.

Args:
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to next.
"""

def __init__(self, transition_to: str) -> None:
super().__init__(transition_to=transition_to)

def is_met(
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
"""Return true as soon as any trial is generated by this GenerationNode."""
return node_that_generated_last_gr == curr_node_name

def block_continued_generation_error(
self,
node_name: Optional[str],
model_name: Optional[str],
experiment: Optional[Experiment],
trials_from_node: Optional[Set[int]] = None,
) -> None:
"""Error to be raised if the `block_gen_if_met` flag is set to True."""
pass


class TrialBasedCriterion(TransitionCriterion):
"""Common class for transition criterion that are based on trial information.

Expand Down Expand Up @@ -219,6 +255,8 @@ def is_met(
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
block_continued_generation: Optional[bool] = False,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
"""Returns if this criterion has been met given its constraints.
Args:
Expand Down Expand Up @@ -490,7 +528,11 @@ def __init__(
)

def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
# TODO: @mgarrard replace fetch_data with lookup_data
data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]])
Expand Down Expand Up @@ -527,7 +569,11 @@ def __init__(
super().__init__(transition_to=transition_to)

def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
return len(experiment.trial_indices_by_status[self.status]) >= self.threshold

Expand Down