Skip to content

Commit

Permalink
fix!: make learning rate a one-of
Browse files Browse the repository at this point in the history
docs: A few small updates
feat: Add `learning_rate_multiplier` to tuning `Hyperparameters`

PiperOrigin-RevId: 616144364
  • Loading branch information
Google APIs authored and Copybara-Service committed Mar 15, 2024
1 parent fe20507 commit 074ea98
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 21 deletions.
1 change: 1 addition & 0 deletions google/ai/generativelanguage/v1beta/BUILD.bazel
Expand Up @@ -345,6 +345,7 @@ load(

csharp_proto_library(
name = "generativelanguage_csharp_proto",
extra_opts = [],
deps = [":generativelanguage_proto"],
)

Expand Down
25 changes: 14 additions & 11 deletions google/ai/generativelanguage/v1beta/generative_service.proto
Expand Up @@ -42,6 +42,10 @@ service GenerativeService {
option (google.api.http) = {
post: "/v1beta/{model=models/*}:generateContent"
body: "*"
additional_bindings {
post: "/v1beta/{model=tunedModels/*}:generateContent"
body: "*"
}
};
option (google.api.method_signature) = "model,contents";
}
Expand Down Expand Up @@ -185,18 +189,16 @@ message GenerationConfig {

// Optional. The maximum number of tokens to include in a candidate.
//
// If unset, this will default to output_token_limit specified in the `Model`
// specification.
// Note: The default value varies by model, see the `Model.output_token_limit`
// attribute of the `Model` returned from the `getModel` function.
optional int32 max_output_tokens = 4 [(google.api.field_behavior) = OPTIONAL];

// Optional. Controls the randomness of the output.
//
// Note: The default value varies by model, see the `Model.temperature`
// attribute of the `Model` returned the `getModel` function.
// attribute of the `Model` returned from the `getModel` function.
//
// Values can range from [0.0,1.0],
// inclusive. A value closer to 1.0 will produce responses that are more
// varied and creative, while a value closer to 0.0 will typically result in
// more straightforward responses from the model.
// Values can range from [0.0, infinity).
optional float temperature = 5 [(google.api.field_behavior) = OPTIONAL];

// Optional. The maximum cumulative probability of tokens to consider when
Expand All @@ -210,18 +212,17 @@ message GenerationConfig {
// of tokens based on the cumulative probability.
//
// Note: The default value varies by model, see the `Model.top_p`
// attribute of the `Model` returned the `getModel` function.
// attribute of the `Model` returned from the `getModel` function.
optional float top_p = 6 [(google.api.field_behavior) = OPTIONAL];

// Optional. The maximum number of tokens to consider when sampling.
//
// The model uses combined Top-k and nucleus sampling.
//
// Top-k sampling considers the set of `top_k` most probable tokens.
// Defaults to 40.
//
// Note: The default value varies by model, see the `Model.top_k`
// attribute of the `Model` returned the `getModel` function.
// attribute of the `Model` returned from the `getModel` function.
optional int32 top_k = 7 [(google.api.field_behavior) = OPTIONAL];
}

Expand Down Expand Up @@ -457,7 +458,9 @@ message GenerateAnswerRequest {
// overrides the default settings for each `SafetyCategory` specified in the
// safety_settings. If there is no `SafetySetting` for a given
// `SafetyCategory` provided in the list, the API will use the default safety
// setting for that category.
// setting for that category. Harm categories HARM_CATEGORY_HATE_SPEECH,
// HARM_CATEGORY_SEXUALLY_EXPLICIT, HARM_CATEGORY_DANGEROUS_CONTENT,
// HARM_CATEGORY_HARASSMENT are supported.
repeated SafetySetting safety_settings = 3
[(google.api.field_behavior) = OPTIONAL];

Expand Down
4 changes: 2 additions & 2 deletions google/ai/generativelanguage/v1beta/safety.proto
Expand Up @@ -34,10 +34,10 @@ enum HarmCategory {
// Negative or harmful comments targeting identity and/or protected attribute.
HARM_CATEGORY_DEROGATORY = 1;

// Content that is rude, disrepspectful, or profane.
// Content that is rude, disrespectful, or profane.
HARM_CATEGORY_TOXICITY = 2;

// Describes scenarios depictng violence against an individual or group, or
// Describes scenarios depicting violence against an individual or group, or
// general descriptions of gore.
HARM_CATEGORY_VIOLENCE = 3;

Expand Down
33 changes: 25 additions & 8 deletions google/ai/generativelanguage/v1beta/tuned_model.proto
Expand Up @@ -174,21 +174,38 @@ message TuningTask {
Hyperparameters hyperparameters = 5 [(google.api.field_behavior) = IMMUTABLE];
}

// Hyperparameters controlling the tuning process.
// Hyperparameters controlling the tuning process. Read more at
// https://ai.google.dev/docs/model_tuning_guidance
message Hyperparameters {
// Options for specifying learning rate during tuning.
oneof learning_rate_option {
// Optional. Immutable. The learning rate hyperparameter for tuning.
// If not set, a default of 0.001 or 0.0002 will be calculated based on the
// number of training examples.
float learning_rate = 16 [
(google.api.field_behavior) = IMMUTABLE,
(google.api.field_behavior) = OPTIONAL
];

// Optional. Immutable. The learning rate multiplier is used to calculate a
// final learning_rate based on the default (recommended) value. Actual
// learning rate := learning_rate_multiplier * default learning rate Default
// learning rate is dependent on base model and dataset size. If not set, a
// default of 1.0 will be used.
float learning_rate_multiplier = 17 [
(google.api.field_behavior) = IMMUTABLE,
(google.api.field_behavior) = OPTIONAL
];
}

// Immutable. The number of training epochs. An epoch is one pass through the
// training data. If not set, a default of 10 will be used.
// training data. If not set, a default of 5 will be used.
optional int32 epoch_count = 14 [(google.api.field_behavior) = IMMUTABLE];

// Immutable. The batch size hyperparameter for tuning.
// If not set, a default of 16 or 64 will be used based on the number of
// If not set, a default of 4 or 16 will be used based on the number of
// training examples.
optional int32 batch_size = 15 [(google.api.field_behavior) = IMMUTABLE];

// Immutable. The learning rate hyperparameter for tuning.
// If not set, a default of 0.0002 or 0.002 will be calculated based on the
// number of training examples.
optional float learning_rate = 16 [(google.api.field_behavior) = IMMUTABLE];
}

// Dataset for training or validation.
Expand Down

0 comments on commit 074ea98

Please sign in to comment.