Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimplifyParameterConstraints #2326

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
69 changes: 69 additions & 0 deletions ax/modelbridge/transforms/simplify_parameter_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import math
from typing import List, TYPE_CHECKING

from ax.core.parameter import FixedParameter, ParameterType, RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.typeutils import checked_cast

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


class SimplifyParameterConstraints(Transform):
"""Convert parameter constraints on one parameter to an updated bound.

This transform converts parameter constraints on only one parameter into an updated
upper or lower bound. Note that this transform will convert parameters that can only
take on one value into a `FixedParameter`. Make sure this transform is applied
before `RemoveFixed` if you want to remove all fixed parameters.
"""

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
# keeps track of the constraints that cannot be converted to bounds
nontrivial_constraints: List[ParameterConstraint] = []
for pc in search_space.parameter_constraints:
if len(pc.constraint_dict) == 1:
# This can be turned into an updated bound since only one variable is
# involved in the constraint.
[(p_name, weight)] = pc.constraint_dict.items()
# NOTE: We only allow parameter constraints on range parameters
p = checked_cast(RangeParameter, search_space.parameters[p_name])
lb, ub = p.lower, p.upper
if weight == 0 and pc.bound < 0: # Cannot be satisfied
raise ValueError(
"Parameter constraint cannot be satisfied since the weight "
"is zero and the bound is negative."
)
elif weight == 0: # Constraint is always satisfied
continue
elif weight > 0: # New upper bound
ub = float(pc.bound) / float(weight)
if p.parameter_type == ParameterType.INT:
ub = math.floor(ub) # Round down
else: # New lower bound
lb = float(pc.bound) / float(weight)
if p.parameter_type == ParameterType.INT:
lb = math.ceil(lb) # Round up

if lb == ub: # Need to turn this into a fixed parameter
search_space.parameters[p_name] = FixedParameter(
name=p_name, parameter_type=p.parameter_type, value=lb
)
elif weight > 0:
p._upper = ub
else:
p._lower = lb
else:
nontrivial_constraints.append(pc)
search_space.set_parameter_constraints(nontrivial_constraints)
return search_space
122 changes: 122 additions & 0 deletions ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from copy import deepcopy
from typing import List

from ax.core.observation import ObservationFeatures
from ax.core.parameter import (
ChoiceParameter,
FixedParameter,
Parameter,
ParameterType,
RangeParameter,
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.simplify_parameter_constraints import (
SimplifyParameterConstraints,
)
from ax.utils.common.testutils import TestCase


class SimplifyParameterConstraintsTest(TestCase):
def setUp(self) -> None:
self.parameters: List[Parameter] = [
RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.FLOAT),
RangeParameter("y", lower=2, upper=5, parameter_type=ParameterType.INT),
ChoiceParameter(
"z", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
]
self.observation_features = [
ObservationFeatures(parameters={"x": 2, "y": 2, "z": "b"})
]

def test_transform_no_constraints(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(parameters=self.parameters)
ss_transformed = t.transform_search_space(search_space=ss)
self.assertEqual(ss, ss_transformed)
self.assertEqual(
self.observation_features,
t.transform_observation_features(self.observation_features),
)

def test_transform_weight_zero(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 0}, bound=1)
],
)
ss_transformed = t.transform_search_space(search_space=deepcopy(ss))
self.assertEqual(ss_transformed.parameter_constraints, [])
self.assertEqual(ss.parameters, ss_transformed.parameters)
ss_raises = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 0}, bound=-1)
],
)
with self.assertRaisesRegex(
ValueError, "Parameter constraint cannot be satisfied since the weight"
):
ss_transformed = t.transform_search_space(search_space=deepcopy(ss_raises))

def test_transform_search_space(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 1}, bound=2), # x <= 2
ParameterConstraint(constraint_dict={"y": -1}, bound=-4), # y => 4
],
)
ss_transformed = t.transform_search_space(search_space=deepcopy(ss))
self.assertEqual(
{
**ss.parameters,
"x": RangeParameter(
"x", parameter_type=ParameterType.FLOAT, lower=1, upper=2
),
"y": RangeParameter(
"y", parameter_type=ParameterType.INT, lower=4, upper=5
),
},
ss_transformed.parameters,
)
self.assertEqual(ss_transformed.parameter_constraints, [])
self.assertEqual( # No-op
self.observation_features,
t.transform_observation_features(self.observation_features),
)

def test_transform_to_fixed(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 1}, bound=1), # x == 1
ParameterConstraint(constraint_dict={"y": -1}, bound=-5), # y == 5
],
)
ss_transformed = t.transform_search_space(search_space=deepcopy(ss))
self.assertEqual(
{
**ss.parameters,
"x": FixedParameter("x", parameter_type=ParameterType.FLOAT, value=1),
"y": FixedParameter("y", parameter_type=ParameterType.INT, value=5),
},
ss_transformed.parameters,
)
self.assertEqual(ss_transformed.parameter_constraints, [])
self.assertEqual( # No-op
self.observation_features,
t.transform_observation_features(self.observation_features),
)
4 changes: 4 additions & 0 deletions ax/storage/transform_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
)
from ax.modelbridge.transforms.remove_fixed import RemoveFixed
from ax.modelbridge.transforms.search_space_to_choice import SearchSpaceToChoice
from ax.modelbridge.transforms.simplify_parameter_constraints import (
SimplifyParameterConstraints,
)
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY
from ax.modelbridge.transforms.task_encode import TaskEncode
Expand Down Expand Up @@ -79,6 +82,7 @@
LogY: 23,
Relativize: 24,
RelativizeWithConstantControl: 25,
SimplifyParameterConstraints: 26,
}


Expand Down
8 changes: 8 additions & 0 deletions sphinx/source/modelbridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ Transforms
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.simplify_parameter_constraints`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.modelbridge.transforms.simplify_parameter_constraints
:members:
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.standardize\_y`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down