Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load mistral and mixtral models from GCS #2893

Merged
merged 5 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -95,7 +95,8 @@
"# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).\n",
"\n",
"# @markdown 2. [Optional] [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region in which the notebook was launched. Note that a multi-region bucket (e.g. \"us\") is not considered a match for a single region covered by the multi-region range (e.g. \"us-central1\"). If not set, a unique GCS bucket will be created instead.\n",
"# Import the necessary packages\n",
"\n",
"# Import the necessary packages.\n",
"import os\n",
"from datetime import datetime\n",
"from typing import Tuple\n",
Expand All @@ -116,16 +117,8 @@
"# A unique GCS bucket will be created for the purpose of this notebook. If you\n",
"# prefer using your own GCS bucket, change the value yourself below.\n",
"now = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
"BUCKET_URI = \"gs://\" # @param {type:\"string\"}\n",
"BUCKET_URI = \"gs://\" # @param {type: \"string\"}\n",
"assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n",
"if BUCKET_URI is None or BUCKET_URI.strip() == \"\" or BUCKET_URI == \"gs://\":\n",
" BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}\"\n",
"\n",
"STAGING_BUCKET = os.path.join(BUCKET_URI, \"temporal\")\n",
"\n",
"# Initialize Vertex AI API.\n",
"print(\"Initializing Vertex AI API.\")\n",
"aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)\n",
"\n",
"# Gets the default BUCKET_URI and SERVICE_ACCOUNT if they were not specified by the user.\n",
"\n",
Expand All @@ -136,15 +129,18 @@
"print(\"Using this default Service Account:\", SERVICE_ACCOUNT)\n",
"\n",
"# Create a unique GCS bucket for this notebook, if not specified by the user.\n",
"assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n",
"if BUCKET_URI is None or BUCKET_URI.strip() == \"\" or BUCKET_URI == \"gs://\":\n",
" # Create a unique GCS bucket for this notebook, if not specified by the user\n",
" BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}\"\n",
" ! gsutil mb -l {REGION} {BUCKET_URI}\n",
"else:\n",
" shell_output = ! gsutil ls -Lb {BUCKET_URI} | grep \"Location constraint:\" | sed \"s/Location constraint://\"\n",
" BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
" shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep \"Location constraint:\" | sed \"s/Location constraint://\"\n",
" bucket_region = shell_output[0].strip().lower()\n",
" if bucket_region != REGION:\n",
" raise ValueError(\n",
" \"Bucket region %s is different from notebook region %s\"\n",
" \"Bucket region '%s' is different from notebook region '%s'\"\n",
" % (bucket_region, REGION)\n",
" )\n",
"\n",
Expand All @@ -156,6 +152,9 @@
"\n",
"! gcloud config set project $PROJECT_ID\n",
"\n",
"# Initialize Vertex AI API.\n",
"aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)\n",
"\n",
"# The pre-built serving docker images with vLLM\n",
"VLLM_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20240313_0916_RC00\"\n",
"\n",
Expand Down Expand Up @@ -221,7 +220,7 @@
" vllm_args = [\n",
" \"--host=0.0.0.0\",\n",
" \"--port=7080\",\n",
" f\"--model={model_id}\",\n",
" f\"--model=gs://vertex-model-garden-public-us/{model_id}\",\n",
" f\"--tensor-parallel-size={accelerator_count}\",\n",
" \"--swap-space=16\",\n",
" f\"--dtype={dtype}\",\n",
Expand Down Expand Up @@ -406,7 +405,7 @@
"\n",
"delete_bucket = False # @param {type:\"boolean\"}\n",
"if delete_bucket:\n",
" ! gsutil -m rm $BUCKET_URI"
" ! gsutil -m rm -r $BUCKET_URI"
]
}
],
Expand Down
Expand Up @@ -192,7 +192,7 @@
"\n",
" model = aiplatform.Model.upload(\n",
" display_name=model_name,\n",
" serving_container_image_uri=vllm_docker_uri,\n",
" serving_container_image_uri=VLLM_DOCKER_URI,\n",
" serving_container_command=[\"python\", \"-m\", \"vllm.entrypoints.api_server\"],\n",
" serving_container_args=vllm_args,\n",
" serving_container_ports=[7080],\n",
Expand Down Expand Up @@ -259,11 +259,11 @@
"\n",
"# Huggingface dataset name or gs:// URI to a custom JSONL dataset.\n",
"base_model_id = \"mistralai/Mistral-7B-v0.1\"\n",
"\n",
"gcs_model_id = f\"gs://vertex-model-garden-public-us/{base_model_id}\"\n",
"dataset_name = \"fredmo/vertexai-qna-500\" # @param {type:\"string\"}\n",
"\n",
"# Optional. Template name or gs:// URI to a custom template.\n",
"template = \"vertex_sample\" # @param {type:\"string\"}\n",
"template = \"vertex_sample\" # @param {type:\"string\"}\n",
"\n",
"# Number of training steps.\n",
"max_steps = 10 # @param {type:\"integer\"}\n",
Expand All @@ -274,7 +274,7 @@
"lora_dropout = 0.1 # @param{type:\"number\"}\n",
"\n",
"# Learning rate.\n",
"learning_rate = 0.01 # @param{type:\"number\"}\n",
"learning_rate = 0.0001 # @param{type:\"number\"}\n",
"\n",
"# Precision mode for finetuning.\n",
"finetuning_precision_mode = \"float16\"\n",
Expand All @@ -288,8 +288,7 @@
"elif accelerator_type == \"NVIDIA_L4\":\n",
" machine_type = \"g2-standard-8\"\n",
" accelerator_count = 1\n",
"elif\n",
" accelerator_type == \"NVIDIA_TESLA_A100\":\n",
"elif accelerator_type == \"NVIDIA_TESLA_A100\":\n",
" machine_type = \"a2-highgpu-1g\"\n",
" accelerator_count = 1\n",
"else:\n",
Expand Down Expand Up @@ -326,7 +325,7 @@
"train_job.run(\n",
" args=[\n",
" \"--task=causal-language-modeling-lora\",\n",
" f\"--pretrained_model_id={base_model_id}\",\n",
" f\"--pretrained_model_id={gcs_model_id}\",\n",
" f\"--dataset_name={dataset_name}\",\n",
" f\"--output_dir={lora_output_dir}\",\n",
" f\"--merge_base_and_lora_output_dir={merged_model_output_dir}\",\n",
Expand Down Expand Up @@ -418,8 +417,8 @@
"# )\n",
"# endpoint = aiplatform.Endpoint(aip_endpoint_name)\n",
"\n",
"prompt = \"What is Model Garden?\" # @param {type: \"string\"}\n",
"max_tokens = 50 # @param {type:\"integer\"}\n",
"prompt = \"What is Vertex AI?\" # @param {type: \"string\"}\n",
"max_tokens = 100 # @param {type:\"integer\"}\n",
"temperature = 1.0 # @param {type:\"number\"}\n",
"top_p = 1.0 # @param {type:\"number\"}\n",
"top_k = 1 # @param {type:\"integer\"}\n",
Expand All @@ -433,7 +432,9 @@
"}\n",
"\n",
"response = endpoint.predict(instances=[instance])\n",
"print(response.predictions[0])"
"\n",
"for prediction in response.predictions:\n",
" print(prediction)"
]
},
{
Expand All @@ -454,7 +455,7 @@
"\n",
"delete_bucket = False # @param {type:\"boolean\"}\n",
"if delete_bucket:\n",
" ! gsutil -m rm $BUCKET_URI"
" ! gsutil -m rm -r $BUCKET_URI"
]
}
],
Expand Down
Expand Up @@ -217,10 +217,17 @@
" if accelerator_type in [\"NVIDIA_TESLA_T4\", \"NVIDIA_TESLA_V100\"]:\n",
" dtype = \"float16\"\n",
"\n",
" if \"asia\" in REGION:\n",
" region_suffix = \"asia\"\n",
" elif \"europe\" in REGION:\n",
" region_suffix = \"eu\"\n",
" else:\n",
" region_suffix = \"us\"\n",
"\n",
" vllm_args = [\n",
" \"--host=0.0.0.0\",\n",
" \"--port=7080\",\n",
" f\"--model={model_id}\",\n",
" f\"--model=gs://vertex-model-garden-public-{region_suffix}/{model_id}\",\n",
" f\"--tensor-parallel-size={accelerator_count}\",\n",
" \"--swap-space=16\",\n",
" f\"--dtype={dtype}\",\n",
Expand Down Expand Up @@ -402,7 +409,7 @@
"\n",
"delete_bucket = False # @param {type:\"boolean\"}\n",
"if delete_bucket:\n",
" ! gsutil -m rm $BUCKET_URI"
" ! gsutil -m rm -r $BUCKET_URI"
]
}
],
Expand Down
Expand Up @@ -261,7 +261,13 @@
"\n",
"# Huggingface dataset name or gs:// URI to a custom JSONL dataset.\n",
"base_model_id = \"mistralai/Mixtral-8x7B\"\n",
"\n",
"if \"asia\" in REGION:\n",
" region_suffix = \"asia\"\n",
"elif \"europe\" in REGION:\n",
" region_suffix = \"eu\"\n",
"else:\n",
" region_suffix = \"us\"\n",
"gcs_model_id = f\"gs://vertex-model-garden-public-{region_suffix}/{base_model_id}\"\n",
"dataset_name = \"fredmo/vertexai-qna-500\" # @param {type:\"string\"}\n",
"\n",
"# Optional. Template name or gs:// URI to a custom template.\n",
Expand Down Expand Up @@ -326,7 +332,7 @@
"train_job.run(\n",
" args=[\n",
" \"--task=causal-language-modeling-lora\",\n",
" f\"--pretrained_model_id={base_model_id}\",\n",
" f\"--pretrained_model_id={gcs_model_id}\",\n",
" f\"--dataset_name={dataset_name}\",\n",
" f\"--output_dir={lora_output_dir}\",\n",
" f\"--merge_base_and_lora_output_dir={merged_model_output_dir}\",\n",
Expand Down