
Performance on single GPU is much better than on Multi-GPUs #2754

Open
3 of 4 tasks
baicenxiao opened this issue May 8, 2024 · 3 comments

@baicenxiao

baicenxiao commented May 8, 2024

System Info

accelerate==0.17.1, python==3.9, pytorch==2.0.0, 20.04.1-Ubuntu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

According to the performance comparison guideline, training on multiple GPUs with a per-GPU batch size of batch_size_multi_gpu should give performance similar to training on a single GPU with batch_size_single_gpu = batch_size_multi_gpu * num_GPUs.
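The reasoning behind this guideline can be checked on a toy problem. The sketch below is a hypothetical illustration, not Accelerate code: it uses a one-parameter linear model with squared-error loss and mimics the gradient all-reduce of data-parallel training by averaging per-process gradients. Four processes with one sample each produce the same gradient as one process with a batch of four, assuming each sample's loss does not depend on the other samples in the batch (layers such as batch normalization break that assumption).

```python
# Toy model: prediction = w * x, per-sample loss = (w * x - y) ** 2.
def grad_of_mean_loss(w, batch):
    """Gradient d/dw of the mean squared error over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def single_gpu_grad(w, samples):
    # One process sees the whole batch.
    return grad_of_mean_loss(w, samples)


def multi_gpu_grad(w, samples, num_processes):
    # Hypothetical data-parallel step: split the batch evenly across processes,
    # compute each local gradient, then average them (an all-reduce mean).
    assert len(samples) % num_processes == 0, "batch must split evenly"
    per_process = len(samples) // num_processes
    shards = [samples[i * per_process:(i + 1) * per_process] for i in range(num_processes)]
    local_grads = [grad_of_mean_loss(w, shard) for shard in shards]
    return sum(local_grads) / num_processes


samples = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]
w = 0.5
print(single_gpu_grad(w, samples))                  # batch of 4 on one "GPU": -23.0
print(multi_gpu_grad(w, samples, num_processes=4))  # batch of 1 on each of 4 "GPUs": -23.0
```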

But when I test the official example code, single-GPU training with batch_size=4 gives much better evaluation accuracy than 4-GPU training with batch_size=1. I disabled the learning rate scheduler in case it steps differently in distributed training.

What I changed in the official example script:

  1. Removed the learning rate scheduler step (lr_scheduler.step() is commented out in the training loop).
  2. Changed the batch size: batch_size=1 for training on 4 GPUs and batch_size=4 for training on a single GPU.
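Under a simplified model of data sharding (an assumption here, not taken from Accelerate's source), both configurations should run the same number of optimizer steps per epoch, which agrees with the 1478 training steps shown in both training logs. The training-set size below is illustrative: it is chosen to be consistent with those logs, and the issue does not state the actual number of images.

```python
import math


def steps_per_epoch(num_samples, per_device_batch_size, num_processes):
    # Simplified model: each process gets about num_samples / num_processes samples
    # and iterates over them in batches of per_device_batch_size (last batch may be partial).
    global_batch_size = per_device_batch_size * num_processes
    return math.ceil(num_samples / global_batch_size)


# Illustrative value consistent with the 1478 steps in the logs (not stated in the issue).
num_train = 5912

single = steps_per_epoch(num_train, per_device_batch_size=4, num_processes=1)
multi = steps_per_epoch(num_train, per_device_batch_size=1, num_processes=4)
print(single, multi)  # 1478 1478
```

Because the step counts match, the difference in accuracy is not explained by one setup simply taking more optimizer steps.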

FYI, below are the configs and training outputs.

Single GPU:

(shadow) bxiao@ip-10-45-101-134:/sensei-fs/users/bxiao/test_multiGPUs$ accelerate config --config_file config.yaml
----------------------------------------
In which compute environment are you running?
This machine
----------------------------------------
Which type of machine are you using?
No distributed training
Do you want to run your training on CPU only (even if a GPU / Apple Silicon device is available)? [yes/NO]:
Do you wish to optimize your script with torch dynamo?[yes/NO]:
Do you want to use DeepSpeed? [yes/NO]:
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:
----------------------------------------
Do you wish to use FP16 or BF16 (mixed precision)?
no
accelerate configuration saved at config.yaml
(shadow) bxiao@ip-10-45-101-134:/sensei-fs/users/bxiao/test_multiGPUs$ accelerate launch --config_file config.yaml ./cv_example.py --data_dir ./images
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
0.17.1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:28<00:00, 51.96it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:07<00:00, 47.86it/s]
epoch 0: 84.84
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:27<00:00, 53.54it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:07<00:00, 49.41it/s]
epoch 1: 88.43
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:27<00:00, 53.63it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:07<00:00, 47.57it/s]
epoch 2: 87.28

4 GPUs:

(shadow) bxiao@ip-10-45-101-134:/sensei-fs/users/bxiao/test_multiGPUs$ accelerate config --config_file config.yaml
----------------------------------------
In which compute environment are you running?
This machine
----------------------------------------
Which type of machine are you using?
multi-GPU
How many different machines will you use (use more than 1 for multi-node training)? [1]:
Do you wish to optimize your script with torch dynamo?[yes/NO]:
Do you want to use DeepSpeed? [yes/NO]:
Do you want to use FullyShardedDataParallel? [yes/NO]:
Do you want to use Megatron-LM ? [yes/NO]:
How many GPU(s) should be used for distributed training? [1]:4
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:
----------------------------------------
Do you wish to use FP16 or BF16 (mixed precision)?
no
accelerate configuration saved at config.yaml
(shadow) bxiao@ip-10-45-101-134:/sensei-fs/users/bxiao/test_multiGPUs$ accelerate launch --config_file config.yaml ./cv_example.py --data_dir ./images
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:35<00:00, 41.53it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:10<00:00, 35.46it/s]
epoch 0: 63.67
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:34<00:00, 42.84it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:10<00:00, 35.51it/s]
epoch 1: 76.73
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:34<00:00, 42.36it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:10<00:00, 35.08it/s]
epoch 2: 76.32
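The gap between the two runs can be tabulated from the "epoch N: ACC" lines. The snippet below is an illustrative helper (not part of the example script) that copies the accuracies from the two logs above and prints the per-epoch difference, which works out to 21.17, 11.70 and 10.96 percentage points.

```python
import re

single_gpu_log = """epoch 0: 84.84
epoch 1: 88.43
epoch 2: 87.28"""

multi_gpu_log = """epoch 0: 63.67
epoch 1: 76.73
epoch 2: 76.32"""


def parse_accuracies(log):
    # Match the "epoch N: ACC" lines printed by accelerator.print in the script.
    return {int(e): float(acc) for e, acc in re.findall(r"epoch (\d+): ([\d.]+)", log)}


single = parse_accuracies(single_gpu_log)
multi = parse_accuracies(multi_gpu_log)
for epoch in sorted(single):
    gap = single[epoch] - multi[epoch]
    print(f"epoch {epoch}: single={single[epoch]:.2f} multi={multi[epoch]:.2f} gap={gap:.2f}")
```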

Here is the script I modified:

# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
from tqdm import tqdm
import numpy as np
import PIL
import torch
from timm import create_model
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor

from accelerate import Accelerator
import accelerate
print(accelerate.__version__)


########################################################################
# This is a fully working simple example to use Accelerate
#
# This example trains a ResNet50 on the Oxford-IIT Pet Dataset
# in any of the following settings (with the same script):
#   - single CPU or single GPU
#   - multi GPUS (using PyTorch distributed mode)
#   - (multi) TPUs
#   - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, follow the instructions
# in the readme for examples:
# https://github.com/huggingface/accelerate/tree/main/examples
#
########################################################################


# Function to get the label from the filename
def extract_label(fname):
    stem = fname.split(os.path.sep)[-1]
    return re.search(r"^(.*)_\d+\.jpg$", stem).groups()[0]


class PetsDataset(Dataset):
    def __init__(self, file_names, image_transform=None, label_to_id=None):
        self.file_names = file_names
        self.image_transform = image_transform
        self.label_to_id = label_to_id

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        fname = self.file_names[idx]
        raw_image = PIL.Image.open(fname)
        image = raw_image.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)
        label = extract_label(fname)
        if self.label_to_id is not None:
            label = self.label_to_id[label]
        return {"image": image, "label": label}


def training_function(config, args):
    # Initialize accelerator
    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    image_size = config["image_size"]
    if not isinstance(image_size, (list, tuple)):
        image_size = (image_size, image_size)

    # Grab all the image filenames
    file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(".jpg")]

    # Build the label correspondences
    all_labels = [extract_label(fname) for fname in file_names]
    id_to_label = list(set(all_labels))
    id_to_label.sort()
    label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}

    # Set the seed before splitting the data.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Split our filenames between train and validation
    random_perm = np.random.permutation(len(file_names))
    cut = int(0.8 * len(file_names))
    train_split = random_perm[:cut]
    eval_split = random_perm[cut:]

    # For training we use a simple RandomResizedCrop
    train_tfm = Compose([RandomResizedCrop(image_size, scale=(0.5, 1.0)), ToTensor()])
    train_dataset = PetsDataset(
        [file_names[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id
    )

    # For evaluation, we use a deterministic Resize
    eval_tfm = Compose([Resize(image_size), ToTensor()])
    eval_dataset = PetsDataset([file_names[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)

    # Instantiate dataloaders.
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=4)

    # Instantiate the model (we build the model here so that the seed also controls the initialization of new weights)
    model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))

    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).
    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
    model = model.to(accelerator.device)

    # Freezing the base model
    for param in model.parameters():
        param.requires_grad = False
    for param in model.get_classifier().parameters():
        param.requires_grad = True

    # We normalize the batches of images to be a bit faster.
    mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None].to(accelerator.device)
    std = torch.tensor(model.default_cfg["std"])[None, :, None, None].to(accelerator.device)

    # Instantiate optimizer
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr / 25)

    # Instantiate learning rate scheduler
    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader))

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Now we train the model
    for epoch in range(num_epochs):
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, disable=not accelerator.is_local_main_process)):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
            accelerator.backward(loss)
            optimizer.step()
            # lr_scheduler.step()  # disabled for this comparison
            optimizer.zero_grad()

        model.eval()
        accurate = 0
        num_elems = 0
        for _, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            with torch.no_grad():
                outputs = model(inputs)
            predictions = outputs.argmax(dim=-1)
            predictions, references = accelerator.gather_for_metrics((predictions, batch["label"]))
            accurate_preds = predictions == references
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        eval_metric = accurate.item() / num_elems
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")


def main():
    parser = argparse.ArgumentParser(description="Simple example of training script.")
    parser.add_argument("--data_dir", required=True, help="The data folder on disk.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16", "fp8"],
        help="Whether to use mixed precision. Choose "
        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
        "and an Nvidia Ampere GPU.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
    args = parser.parse_args()
    config = {"lr": 3e-2, "num_epochs": 3, "seed": 42, "batch_size": 1, "image_size": 224}
    training_function(config, args)


if __name__ == "__main__":
    main()

Expected behavior

When batch_size is set so that batch_size_single_gpu = batch_size_multi_gpu * num_GPUs, training on a single GPU should give performance similar to training on multiple GPUs.

@muellerzr
Collaborator

Have you also tried scaling the learning rate by the number of GPUs? (What I mean is that in multi-GPU training the scheduler is stepped N times, which could account for some of this.)
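The stepping effect can be illustrated with a toy scheduler. This sketch is hypothetical (ToyScheduler and lr_after are not Accelerate or PyTorch APIs), and it uses a plain linear decay instead of OneCycleLR to keep the arithmetic simple: if the schedule advances 4 positions per optimizer step, then after 100 optimizer steps it is 40% of the way through rather than 10%, so the learning rate at a given optimizer step differs from a single-process run unless total_steps accounts for it.

```python
def linear_decay_lr(base_lr, step, total_steps):
    # Toy schedule: decay linearly from base_lr to 0 over total_steps schedule positions.
    return base_lr * max(0.0, 1 - step / total_steps)


class ToyScheduler:
    """Hypothetical scheduler that advances its schedule `advance_per_call` times per step()."""

    def __init__(self, base_lr, total_steps, advance_per_call=1):
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.advance_per_call = advance_per_call
        self.schedule_step = 0

    def step(self):
        # Called once per optimizer step; moves the schedule forward advance_per_call positions.
        self.schedule_step += self.advance_per_call

    def get_lr(self):
        return linear_decay_lr(self.base_lr, self.schedule_step, self.total_steps)


def lr_after(optimizer_steps, advance_per_call, base_lr=3e-2, total_steps=1000):
    sched = ToyScheduler(base_lr, total_steps, advance_per_call)
    for _ in range(optimizer_steps):
        sched.step()
    return sched.get_lr()


print(lr_after(100, advance_per_call=1))  # schedule is 10% of the way through
print(lr_after(100, advance_per_call=4))  # schedule is 40% of the way through
```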

@baicenxiao
Author

Hi @muellerzr, thanks for the response!

For the experiments above, I have already disabled the learning rate scheduler.

I have also tried adjusting the learning rate with learning_rate *= accelerator.num_processes, as given in the official performance guideline, and I still see a significant difference in training performance.
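A minimal sketch of the linear scaling rule, assuming the factor is applied to config["lr"] in the script above. Note that the script builds Adam with lr / 25, so with lr_scheduler.step() commented out the optimizer keeps that reduced value for the whole run; scaled_learning_rate is a hypothetical helper, and num_processes is hard-coded to what accelerator.num_processes would report on 4 GPUs.

```python
def scaled_learning_rate(base_lr, num_processes):
    # Linear scaling as described in the guideline: learning_rate *= num_processes.
    return base_lr * num_processes


base_lr = 3e-2     # config["lr"] in the script
num_processes = 4  # assumed value of accelerator.num_processes on 4 GPUs
lr = scaled_learning_rate(base_lr, num_processes)

# The script passes lr / 25 to torch.optim.Adam; without scheduler steps it stays there.
adam_lr = lr / 25
print(lr, adam_lr)
```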

FYI, here is the result after using learning_rate *= 4 when training with 4 GPUs:

(shadow) bxiao@ip-10-45-101-134:/sensei-fs/users/bxiao/test_multiGPUs$ accelerate launch --config_file config.yaml ./cv_example.py --data_dir ./images
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
0.17.1
0.17.1
0.17.1
0.17.1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:35<00:00, 41.07it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:10<00:00, 35.43it/s]
epoch 0: 75.24
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:34<00:00, 42.35it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:10<00:00, 36.49it/s]
epoch 1: 76.52
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1478/1478 [00:34<00:00, 42.67it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:10<00:00, 35.66it/s]
epoch 2: 77.33

@muellerzr
Collaborator

Thanks, let me try running this today and see what happens
