Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626499818
  • Loading branch information
Vertex Vision Eng authored and dstnluong-google committed Apr 25, 2024
1 parent dcd2624 commit c2608f8
Show file tree
Hide file tree
Showing 75 changed files with 18,433 additions and 10,160 deletions.
@@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3c19749c-4780-47e7-9da1-c4a86aa227dc",
"metadata": {},
"source": [
"# Local docker run for JAX VIT training\n",
"\n",
"This notebook shows local docker run for JAX VIT training.\n",
"This notebook uses a workbench with TensorFlow 2.11 and 8 v100 GPUs.\n",
"You also need to upload the 'train_vit_gpu.Dockerfile' to the home directory for workbench."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3df6e8a-68c7-40a2-9b24-0ed7acdeffc3",
"metadata": {},
"outputs": [],
"source": [
"# Build training docker\n",
"\n",
"project=\"cloud-nas-260507\"\n",
"image_tag=\"jax-vit-train-gpu-lavrai-test:latest\"\n",
"train_docker_uri=\"gcr.io/{}/{}\".format(project, image_tag)\n",
"\n",
"!docker build -f train_vit_gpu.Dockerfile . -t {image_tag}\n",
"\n",
"!docker tag {image_tag} {train_docker_uri}\n",
"\n",
"!docker push {train_docker_uri}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10978e1a-8cf2-4a07-8b0c-f0e45a88bf53",
"metadata": {},
"outputs": [],
"source": [
"# Docker arguments.\n",
"workdir='tmp'\n",
"docker_args_list=[\n",
" '--config', 'vit_jax/configs/augreg.py:R_Ti_16',\n",
" '--config.dataset', 'tf_flowers',\n",
" '--config.pp.train', 'train[:90%]',\n",
" '--config.pp.test', 'train[90%:]',\n",
" '--config.batch_eval', '120',\n",
" '--config.base_lr', '0.01',\n",
" '--config.shuffle_buffer', '1000',\n",
" '--config.total_steps', '100',\n",
" '--config.warmup_steps', '10',\n",
" '--config.accum_steps', '0', # Not needed with R+Ti/16 model.\n",
" '--config.pp.crop', '224',\n",
" '--workdir', f'{workdir}',\n",
" ]\n",
"print(docker_args_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b36477a6-f083-4098-8745-aebcc6e74098",
"metadata": {},
"outputs": [],
"source": [
"# Utility functions.\n",
"\n",
"import io\n",
"import subprocess\n",
"import sys\n",
"\n",
"def run_command_with_stdout(cmd, job_log_file=None, error_message=\"\"):\n",
" \"\"\"Runs the command and stream the command outputs.\"\"\"\n",
" if job_log_file is None:\n",
" job_log_file = sys.stdout\n",
" buf = io.StringIO()\n",
" ret_code = None\n",
"\n",
" with subprocess.Popen(\n",
" cmd,\n",
" stdin=subprocess.PIPE,\n",
" stdout=subprocess.PIPE,\n",
" stderr=subprocess.STDOUT,\n",
" universal_newlines=False,\n",
" ) as p:\n",
" out = io.TextIOWrapper(p.stdout, newline=\"\")\n",
"\n",
" for line in out:\n",
" buf.write(line)\n",
" job_log_file.write(line)\n",
" job_log_file.flush()\n",
"\n",
" # flush to force the contents to display.\n",
" job_log_file.flush()\n",
"\n",
" while p.poll() is None:\n",
" # Process hasn't exited yet, let's wait some\n",
" time.sleep(0.5)\n",
"\n",
" ret_code = p.returncode\n",
" p.stdout.close()\n",
"\n",
" if ret_code:\n",
" raise RuntimeError(\n",
" \"Error: {} with return code {}\".format(error_message, ret_code)\n",
" )\n",
" return buf.getvalue(), ret_code\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d559395-c1d5-4ccd-8066-054eac3ad2d2",
"metadata": {},
"outputs": [],
"source": [
"# Run local training.\n",
"cmd = ([\"nvidia-docker\", \"run\"] + [\"-t\", train_docker_uri] + docker_args_list)\n",
"run_command_with_stdout(\n",
" cmd, error_message=\"Failed to run docker locally\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88367395-5608-4262-82d9-e94c55f15903",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "python3",
"name": "tf2-gpu.2-11.m108",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-11:m108"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit c2608f8

Please sign in to comment.