Skip to content

Commit

Permalink
Add AutoTransitionCriterion (facebook#2409)
Browse files Browse the repository at this point in the history
Summary:

In order to support multiple models in a single gen call we need to support the `Auto` Transition class. This criterion will automatically move to the next node once anything has been generated from the current node

Reviewed By: saitcakmak

Differential Revision: D56360662
  • Loading branch information
mgarrard authored and facebook-github-bot committed May 9, 2024
1 parent 6676720 commit bd9bb86
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 7 deletions.
21 changes: 20 additions & 1 deletion ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
# ------------------------- Trial logic helpers. -------------------------
@property
def trials_from_node(self) -> Set[int]:
"""Returns a dictionary mapping a GenerationNode to the trials it generated.
"""Returns a set mapping a GenerationNode to the trials it generated.
Returns:
Set[int]: A set containing all the trials indices generated by this node.
Expand All @@ -384,6 +384,19 @@ def trials_from_node(self) -> Set[int]:
trials_from_node.add(trial.index)
return trials_from_node

@property
def node_that_generated_last_gr(self) -> Optional[str]:
"""Returns the name of the node that generated the last generator run.
Returns:
str: The name of the node that generated the last generator run.
"""
return (
self.generation_strategy.last_generator_run._generation_node_name
if self.generation_strategy.last_generator_run
else None
)

def should_transition_to_next_node(
self, raise_data_required_error: bool = True
) -> Tuple[bool, Optional[str]]:
Expand All @@ -409,6 +422,12 @@ def should_transition_to_next_node(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
node_name=self.node_name,
node_that_generated_last_gr=(
self.generation_strategy.last_generator_run._generation_node_name
if self.generation_strategy.last_generator_run is not None
else None
),
)
for tc in transition_blocking
)
Expand Down
113 changes: 113 additions & 0 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGenCriterion,
MaxGenerationParallelism,
MaxTrials,
MinTrials,
Expand Down Expand Up @@ -149,6 +150,14 @@ def setUp(self) -> None:
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
self.single_running_trial_criterion = [
MaxTrials(
threshold=1,
transition_to="gpei",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
)
]
self.sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs=self.step_model_kwargs,
Expand Down Expand Up @@ -1388,6 +1397,110 @@ def test_generation_strategy_eq_no_print(self) -> None:
)
self.assertEqual(gs1, gs2)

def test_node_gs_with_auto_transitions(self) -> None:
"""Test that node-based generation strategies which leverage
AutoTransitionAfterGen criterion correctly transition and create trials.
"""
gs = GenerationStrategy(
nodes=[
# first node should be our exploration node and only grs from this node
# should be on the first trial
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
# node 2,3,4 will be out iteration nodes, and grs from all 3 nodes
# should be used to make the subsequent trials
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei")
],
),
],
)
exp = get_branin_experiment()

self.assertEqual(gs.current_node_name, "sobol")
exp.new_trial(generator_run=gs.gen(exp)).run()
# while here, test the last generator run property on node
self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")

# TODO: @mgarrard modify below test when gen handles multiple nodes
gs_2 = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
transition_criteria=[
MaxTrials(
threshold=2,
transition_to="sobol_3",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
use_all_trials_in_exp=True,
)
],
),
GenerationNode(
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
),
],
)
self.assertEqual(gs_2.current_node_name, "sobol")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "gpei")
gs_2.gen(exp)
self.assertEqual(gs_2.current_node_name, "sobol_2")
gs_2.gen(exp) # noqa
self.assertEqual(gs_2.current_node_name, "gpei_2")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "sobol_3")

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down
52 changes: 49 additions & 3 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.generation_strategy import (
GenerationNode,
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGenCriterion,
MaxGenerationParallelism,
MaxTrials,
MinimumPreferenceOccurances,
Expand All @@ -30,6 +36,15 @@


class TestTransitionCriterion(TestCase):
def setUp(self) -> None:
super().setUp()
self.sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={"init_position": 3},
model_gen_kwargs={"some_gen_kwarg": "some_value"},
)
self.branin_experiment = get_branin_experiment()

def test_minimum_preference_criterion(self) -> None:
"""Tests the minimum preference criterion subcalss of TransitionCriterion."""
criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3)
Expand Down Expand Up @@ -154,7 +169,7 @@ def test_default_step_criterion_setup(self) -> None:

def test_min_trials_is_met(self) -> None:
"""Test that the is_met method in MinTrials works"""
experiment = get_branin_experiment()
experiment = self.branin_experiment
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
Expand Down Expand Up @@ -216,9 +231,33 @@ def test_min_trials_is_met(self) -> None:
trial._status = TrialStatus.EARLY_STOPPED
self.assertTrue(min_criterion.is_met(experiment, gs._steps[0].trials_from_node))

def test_auto_transition(self) -> None:
"""Very simple test to validate AutoTransitionAfterGenCriterion"""
experiment = self.branin_experiment
gs = GenerationStrategy(
name="test",
nodes=[
GenerationNode(
node_name="sobol_1",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
],
),
GenerationNode(
node_name="sobol_2", model_specs=[self.sobol_model_spec]
),
],
)
gs.experiment = experiment
self.assertEqual(gs.current_node_name, "sobol_1")
gs.gen(experiment=experiment)
gs.gen(experiment=experiment)
self.assertEqual(gs.current_node_name, "sobol_2")

def test_max_trials_is_met(self) -> None:
"""Test that the is_met method in MaxTrials works"""
experiment = get_branin_experiment()
experiment = self.branin_experiment
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
Expand Down Expand Up @@ -446,3 +485,10 @@ def test_repr(self) -> None:
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
)
auto_transition = AutoTransitionAfterGenCriterion(
transition_to="GenerationStep_2"
)
self.assertEqual(
str(auto_transition),
"AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2'})",
)
52 changes: 49 additions & 3 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def transition_to(self) -> Optional[str]:

@abstractmethod
def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
node_name: Optional[str] = None,
) -> bool:
"""If the criterion of this TransitionCriterion is met, returns True."""
pass
Expand Down Expand Up @@ -94,6 +98,38 @@ def _unique_id(self) -> str:
return str(self)


class AutoTransitionAfterGenCriterion(TransitionCriterion):
"""A class to designate automatic transition from one GenerationNode to another.
Args:
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to next.
"""

def __init__(self, transition_to: str) -> None:
super().__init__(transition_to=transition_to)

def is_met(
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
node_name: Optional[str] = None,
) -> bool:
"""Return true as soon as any trial is generated by this GenerationNode."""
return node_that_generated_last_gr == node_name

def block_continued_generation_error(
self,
node_name: Optional[str],
model_name: Optional[str],
experiment: Optional[Experiment],
trials_from_node: Optional[Set[int]] = None,
) -> None:
"""Error to be raised if the `block_gen_if_met` flag is set to True."""
pass


class TrialBasedCriterion(TransitionCriterion):
"""Common class for transition criterion that are based on trial information.
Expand Down Expand Up @@ -219,6 +255,8 @@ def is_met(
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
block_continued_generation: Optional[bool] = False,
node_that_generated_last_gr: Optional[str] = None,
node_name: Optional[str] = None,
) -> bool:
"""Returns if this criterion has been met given its constraints.
Args:
Expand Down Expand Up @@ -490,7 +528,11 @@ def __init__(
)

def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
node_name: Optional[str] = None,
) -> bool:
# TODO: @mgarrard replace fetch_data with lookup_data
data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]])
Expand Down Expand Up @@ -527,7 +569,11 @@ def __init__(
super().__init__(transition_to=transition_to)

def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
self,
experiment: Experiment,
trials_from_node: Optional[Set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
node_name: Optional[str] = None,
) -> bool:
return len(experiment.trial_indices_by_status[self.status]) >= self.threshold

Expand Down

0 comments on commit bd9bb86

Please sign in to comment.