Adds lightning fabric integration (#749)
* wip lightning fabric logger

* add fabric tests

* lightning: fix test_lightning_val_updates_to_studio

* lightning: move next_step logic to fabric

* add sync method and use in lightning/fabric callbacks

* fabric: auto-increment step

* add fabric example notebook

* update fabric notebook

* fix mypy errors

* skip fabric tests if not installed

* fix(project): proper lighting extra name reference

* fix(mypy): remove options that is default now

* add back nox

* fix linting issues

* make torch import optional

---------

Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
dberenbaum and shcheklein committed Dec 21, 2023
1 parent a7b7f55 commit b8a8ecf
Showing 8 changed files with 595 additions and 94 deletions.
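For reference, here is a minimal sketch of the logger API this commit introduces, based only on the calls exercised in the example notebook added below (the metric values and file names are illustrative placeholders):

    from dvclive.fabric import DVCLiveLogger

    # Sketch of the API surface added by this commit; see
    # examples/DVCLive-Fabric.ipynb below for the full training loop.
    logger = DVCLiveLogger()                          # or DVCLiveLogger(report="notebook")
    logger.log_hyperparams({"lr": 1.0, "epochs": 5})  # log a dict of hyperparameters
    logger.log_metrics({"loss": 0.42})                # step auto-increments (per the commit notes)
    logger.experiment.log_artifact("mnist_cnn.pt")    # `experiment` exposes the dvclive.Live instance
    logger.finalize("success")                        # save final results as a DVC experiment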
315 changes: 315 additions & 0 deletions examples/DVCLive-Fabric.ipynb
@@ -0,0 +1,315 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "QKSE19fW_Dnj"
},
"source": [
"# DVCLive and Lightning Fabric"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q-C_4R_o_QGG"
},
"source": [
"## Install dvclive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-XFbvwq7TSwN",
"outputId": "15d0e3b5-bb4a-4b3e-d37f-21608d1822ed"
},
"outputs": [],
"source": [
"!pip install \"dvclive[lightning]\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I6S6Uru1_Y0x"
},
"source": [
"## Initialize DVC Repository"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WcbvUl2uTV0y",
"outputId": "aff9740c-26db-483d-ce30-cfef395f3cbb"
},
"outputs": [],
"source": [
"!git init -q\n",
"!git config --local user.email \"you@example.com\"\n",
"!git config --local user.name \"Your Name\"\n",
"!dvc init -q\n",
"!git commit -m \"DVC init\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmY4PLMh_cUk"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "85qErT5yTEbN"
},
"outputs": [],
"source": [
"import argparse\n",
"from os import path\n",
"from types import SimpleNamespace\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import torchvision.transforms as T\n",
"from lightning.fabric import Fabric, seed_everything\n",
"from lightning.fabric.utilities.rank_zero import rank_zero_only\n",
"from torch.optim.lr_scheduler import StepLR\n",
"from torchmetrics.classification import Accuracy\n",
"from torchvision.datasets import MNIST\n",
"\n",
"from dvclive.fabric import DVCLiveLogger\n",
"\n",
"DATASETS_PATH = (\"Datasets\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UrmAHbhr_lgs"
},
"source": [
"## Setup model code\n",
"\n",
"Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/fabric/image_classifier/train_fabric.py.\n",
"\n",
"Look for the `logger` statements where DVCLiveLogger calls were added."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UCzTygUnTHM8"
},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
" self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
" self.dropout1 = nn.Dropout(0.25)\n",
" self.dropout2 = nn.Dropout(0.5)\n",
" self.fc1 = nn.Linear(9216, 128)\n",
" self.fc2 = nn.Linear(128, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = F.relu(x)\n",
" x = self.conv2(x)\n",
" x = F.relu(x)\n",
" x = F.max_pool2d(x, 2)\n",
" x = self.dropout1(x)\n",
" x = torch.flatten(x, 1)\n",
" x = self.fc1(x)\n",
" x = F.relu(x)\n",
" x = self.dropout2(x)\n",
" x = self.fc2(x)\n",
" return F.log_softmax(x, dim=1)\n",
"\n",
"\n",
"def run(hparams):\n",
" # Create the DVCLive Logger\n",
" logger = DVCLiveLogger(report=\"notebook\")\n",
"\n",
" # Log dict of hyperparameters\n",
" logger.log_hyperparams(hparams.__dict__)\n",
"\n",
" # Create the Lightning Fabric object. The parameters like accelerator, strategy, devices etc. will be proided\n",
" # by the command line. See all options: `lightning run model --help`\n",
" fabric = Fabric()\n",
"\n",
" seed_everything(hparams.seed) # instead of torch.manual_seed(...)\n",
"\n",
" transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])\n",
"\n",
" # Let rank 0 download the data first, then everyone will load MNIST\n",
" with fabric.rank_zero_first(local=False): # set `local=True` if your filesystem is not shared between machines\n",
" train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)\n",
" test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)\n",
"\n",
" train_loader = torch.utils.data.DataLoader(\n",
" train_dataset,\n",
" batch_size=hparams.batch_size,\n",
" )\n",
" test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size)\n",
"\n",
" # don't forget to call `setup_dataloaders` to prepare for dataloaders for distributed training.\n",
" train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)\n",
"\n",
" model = Net() # remove call to .to(device)\n",
" optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)\n",
"\n",
" # don't forget to call `setup` to prepare for model / optimizer for distributed training.\n",
" # the model is moved automatically to the right device.\n",
" model, optimizer = fabric.setup(model, optimizer)\n",
"\n",
" scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)\n",
"\n",
" # use torchmetrics instead of manually computing the accuracy\n",
" test_acc = Accuracy(task=\"multiclass\", num_classes=10).to(fabric.device)\n",
"\n",
" # EPOCH LOOP\n",
" for epoch in range(1, hparams.epochs + 1):\n",
" # TRAINING LOOP\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" # NOTE: no need to call `.to(device)` on the data, target\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, target)\n",
" fabric.backward(loss) # instead of loss.backward()\n",
"\n",
" optimizer.step()\n",
" if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):\n",
" print(\n",
" \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
" epoch,\n",
" batch_idx * len(data),\n",
" len(train_loader.dataset),\n",
" 100.0 * batch_idx / len(train_loader),\n",
" loss.item(),\n",
" )\n",
" )\n",
"\n",
" # Log dict of metrics\n",
" logger.log_metrics({\"loss\": loss.item()})\n",
"\n",
" if hparams.dry_run:\n",
" break\n",
"\n",
" scheduler.step()\n",
"\n",
" # TESTING LOOP\n",
" model.eval()\n",
" test_loss = 0\n",
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" # NOTE: no need to call `.to(device)` on the data, target\n",
" output = model(data)\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
"\n",
" # WITHOUT TorchMetrics\n",
" # pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
" # correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
" # WITH TorchMetrics\n",
" test_acc(output, target)\n",
"\n",
" if hparams.dry_run:\n",
" break\n",
"\n",
" # all_gather is used to aggregated the value across processes\n",
" test_loss = fabric.all_gather(test_loss).sum() / len(test_loader.dataset)\n",
"\n",
" print(f\"\\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * test_acc.compute():.0f}%)\\n\")\n",
"\n",
" # log additional metrics\n",
" logger.log_metrics({\"test_loss\": test_loss, \"test_acc\": 100 * test_acc.compute()})\n",
"\n",
" test_acc.reset()\n",
"\n",
" if hparams.dry_run:\n",
" break\n",
"\n",
" # When using distributed training, use `fabric.save`\n",
" # to ensure the current process is allowed to save a checkpoint\n",
" if hparams.save_model:\n",
" fabric.save(\"mnist_cnn.pt\", model.state_dict())\n",
"\n",
" # `logger.experiment` provides access to the `dvclive.Live` instance where you can use additional logging methods.\n",
" # Check that `rank_zero_only.rank == 0` to avoid logging in other processes.\n",
" if rank_zero_only.rank == 0:\n",
" logger.experiment.log_artifact(\"mnist_cnn.pt\")\n",
"\n",
" # Call finalize to save final results as a DVC experiment\n",
" logger.finalize(\"success\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o5_v9lRDAM7l"
},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "BbCXen1PTM4V",
"outputId": "b79c90eb-74cc-474d-c0dd-21245064bca8"
},
"outputs": [],
"source": [
"hparams = SimpleNamespace(batch_size=64, epochs=5, lr=1.0, gamma=0.7, dry_run=False, seed=1, log_interval=10, save_model=True)\n",
"run(hparams)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DnqCrlbLAopV"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -68,7 +68,7 @@ fastai = ["fastai"]
lightning = ["lightning>=2.0", "torch"]
optuna = ["optuna"]
all = [
"dvclive[image,mmcv,tf,xgb,lgbm,huggingface,catalyst,fastai,pytorch-lightning,optuna,plots,markdown]"
"dvclive[image,mmcv,tf,xgb,lgbm,huggingface,catalyst,fastai,lightning,optuna,plots,markdown]"
]

[project.urls]
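The extra is now named after the `lightning` package; as the notebook above shows, it can be installed with:

    pip install "dvclive[lightning]"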
