
Trainer - add cache clearing and the option for batched eval metrics computation #28769

Merged: 20 commits into huggingface:main on May 6, 2024

Conversation

@FoamoftheSea (Contributor, Author) commented Jan 30, 2024

What does this PR do?

This PR does two things that are necessary for using the Trainer in resource-constrained environments (like my RTX-3070Ti machine):

  1. Add cache clearing in the training and evaluation loops.
    • This reduces peak GPU load and prevents CUDA OOM errors when running near capacity.
  2. Add the Trainer arg batch_eval_metrics for batched eval metrics computation.
    • When working with limited RAM, storing all logits across the entire evaluation set may not be feasible. A user working under this constraint can pass True to batch_eval_metrics and construct a compute_metrics function that updates running metrics at the batch level to prevent OOM errors with large eval sets. This is particularly useful for vision transformers (see the usage sketch after this list).
    • Previous functionality is unaltered if the option is not set to True.
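
For illustration, here is a minimal usage sketch of how the new flag and a batch-wise compute_metrics fit together. The RunningAccuracy metric, output directory, model, and eval_ds below are placeholders, not code from this PR:

import torch
from transformers import EvalPrediction, Trainer, TrainingArguments

class RunningAccuracy:
    # Accumulates per-batch counts; returns a metrics dict only when compute_result=True
    def __init__(self):
        self.correct, self.total = 0, 0

    def __call__(self, eval_pred: EvalPrediction, compute_result: bool = False):
        # Assumes the model yields a single logits tensor and that, with
        # batch_eval_metrics enabled, the Trainer passes per-batch torch tensors
        preds = eval_pred.predictions.argmax(-1)
        self.correct += (preds == eval_pred.label_ids).sum().item()
        self.total += eval_pred.label_ids.numel()
        if compute_result:
            result = {"accuracy": self.correct / max(self.total, 1)}
            self.correct, self.total = 0, 0  # reset for the next eval round
            return result

args = TrainingArguments("out", batch_eval_metrics=True)  # new flag from this PR
trainer = Trainer(
    model=model,              # model and eval_ds defined elsewhere
    args=args,
    eval_dataset=eval_ds,
    compute_metrics=RunningAccuracy(),
)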

@muellerzr

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator)

cc @pacman100 and @muellerzr

@FoamoftheSea (Contributor, Author)

Hey everyone, I tried to look at the logs for the failed tests, but I don't see any actionable error reports. Can anyone help me figure out what needs to be done for them to pass?

@ArthurZucker (Collaborator)

The main CI is a bit broken because of the pytest package. Let's wait a bit here.

@ArthurZucker (Collaborator)

Just re-ran the CI; you should actually rebase onto main, then it should be alright.

@ArthurZucker (Collaborator)

BTW @SunMarc, it would be nice if you could have a look as well!

@FoamoftheSea (Contributor, Author)

CIs are green after merging main ✔️

@SunMarc (Member) left a comment

Hi @FoamoftheSea, thanks for contributing! I left a few comments to better understand how you are performing the batched metric computation. Can you also add tests to check that we get the same result with and without batched computation?

Comment on lines +4033 to +4041
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
        compute_result=is_last_step,
    )
else:
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host),
        compute_result=is_last_step,
    )
Member

You can't add a compute_result argument here.

Contributor Author

This code path would only be used if the user set args.batch_eval_metrics to True, so only those trying to use this feature would need to worry about the expectation for this argument. I'm definitely open to other suggestions though.

Contributor Author

Once we settle on the right solution I will write a test for it 🙏

Member

I will let @ArthurZucker and @muellerzr comment on that !

Contributor

This seems okay to me if this path is only ever taken when a user has this enabled. (We should maybe write a snippet/clarification about it in the docstring in TrainingArguments.)

Contributor Author

Sounds good! I'll work on this and an appropriate test for it some time this week

Comment on lines +4033 to +4041
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
        compute_result=is_last_step,
    )
else:
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host),
        compute_result=is_last_step,
    )
Member

Are you computing + storing the batched metric inside self.compute_metrics (it would need to be a class with __call__ defined)?

@FoamoftheSea (Contributor, Author) commented Feb 22, 2024

I'm actually using a globally instanced metrics class that maintains state with update and compute methods inside the compute_metrics function passed to the trainer. There is probably a better solution, but I found this to be the least intrusive on the current expected behavior of the trainer code: if the user does not activate the batched eval metrics option, they never hit the code path where the compute_result argument is used, and therefore don't have to change anything about their existing compute_metrics functions.

Here's a very basic pseudo-code example of how I use this:

from typing import Optional

import numpy as np

from transformers import EvalPrediction, Trainer


class MSEMetric:
    def __init__(self):
        self.batch_mse = []

    def update(self, preds, target):
        diff = target - preds
        batch_mse = np.mean(np.power(diff, 2))
        self.batch_mse.append(batch_mse)

    def compute(self):
        # Get result across entire eval set
        result = {"mse": np.mean(self.batch_mse)}
        # Reset batch statistics
        self.batch_mse = []
        return result


mse_metric = MSEMetric()


def compute_metrics(eval_pred: EvalPrediction, compute_result: bool = True) -> Optional[dict]:
    # EvalPrediction exposes the ground truth as `label_ids`
    mse_metric.update(eval_pred.predictions, eval_pred.label_ids)

    if compute_result:
        return mse_metric.compute()


# Use this compute_metrics fn in trainer
trainer = Trainer(compute_metrics=compute_metrics, ...)

This mirrors the update and compute methodology from the metrics classes in torcheval.metrics, which is where I got the inspiration from (see https://pytorch.org/torcheval/main/generated/torcheval.metrics.MeanSquaredError.html)
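
For reference, here is a minimal torcheval sketch of that same update/compute pattern (independent of this PR; the values are made up):

import torch
from torcheval.metrics import MeanSquaredError

metric = MeanSquaredError()
metric.update(torch.tensor([0.9, 0.5]), torch.tensor([1.0, 0.0]))  # update with one batch
metric.update(torch.tensor([0.2, 0.1]), torch.tensor([0.0, 0.0]))  # update with the next batch
print(metric.compute())  # aggregate MSE over all batches seen so far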

Member

Makes sense! Thanks for explaining.

Comment on lines 3456 to 3472
if self.args.batch_eval_metrics:
    if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
        is_last_step = step == len(dataloader) - 1
        if args.include_inputs_for_metrics:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
                compute_result=is_last_step,
            )
        else:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds_host, label_ids=labels_host),
                compute_result=is_last_step,
            )

    del losses_host, preds_host, inputs_host, labels_host
    torch.cuda.empty_cache()
    losses_host, preds_host, inputs_host, labels_host = None, None, None, None

Member

same comment as above

@@ -3449,8 +3453,25 @@ def evaluation_loop(

self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

if self.args.batch_eval_metrics:
    if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
        is_last_step = step == len(dataloader) - 1
Member

Quick note: not every dataset has __len__() defined, e.g. IterableDataset.
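
For example (Stream here is just a hypothetical iterable dataset to illustrate the point):

from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)

loader = DataLoader(Stream(), batch_size=2)
len(loader)  # raises TypeError, so `step == len(dataloader) - 1` would break here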

Contributor Author

I see, we'll need to find a more robust way to identify the last step of the eval set then... If anyone has an idea let me know

Contributor

You should be able to use self.accelerator.gradient_state.end_of_dataloader here :)
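
In other words, the length-based check inside evaluation_loop could be replaced with something along these lines (sketch):

# Accelerate tracks dataloader state, so this also works for datasets without __len__
is_last_step = self.accelerator.gradient_state.end_of_dataloader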

Contributor

  • adding these changes :)

@SunMarc requested a review from muellerzr on February 23, 2024 at 16:21
@muellerzr (Contributor) left a comment

Thanks! Overall this seems quite handy. If we can confirm that it does reduce your memory footprint without issue then I believe that's quite alright. Replied to comments from Marc's review. Let's apply those then I can give a checkmark on my end at least!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@FoamoftheSea (Contributor, Author)

Commenting to keep the PR fresh. I got super busy the past couple weeks but I will finish this soon.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ducha-aiki (Contributor)

If we can confirm that it does reduce your memory footprint without issue then I believe that's quite alright.

I will check this today, thank you!

@ducha-aiki (Contributor) commented Apr 24, 2024

This works amazing!
Here is the RAM consumption before:

[screenshot: RAM consumption before the change]

And after:

[screenshot: RAM consumption after the change]

The only thing: the PR needs to be updated to work with 4.40, because the all_labels logic has changed since 4.39. I haven't tried to update the PR to 4.40; I tested on its own branch only.

@FoamoftheSea (Contributor, Author)

I am looking into this. It looks like the conflict is due to different management of the variables in the latest version, which actually discards the intermediate variable I was clearing from memory, so I want to double-check how that affects this change.

Also, I'm working on getting a test made for the batching functionality. Should have this ready soon.

@FoamoftheSea (Contributor, Author)

The pytests are ready and I've updated the code to work with the latest Trainer updates. I still need to run an A/B test to see if the cache clearing is still providing any benefit with the new changes from main, since they seem to have the potential to handle the same issue. I just need to finish setting up a test case. Let's hold off on merging until we have some fresh test results.

@FoamoftheSea (Contributor, Author) commented May 6, 2024

The test results demonstrate 2 things:

  1. batch_eval_metrics is essential when training models with large solution spaces, such as semantic segmentation, where the typical code path accumulates many large tensors in memory during the eval loop; batching the calculations avoids this.
  2. The CUDA cache clearing keeps maximum GPU utilization lower over time, and allows using a larger eval batch size to expedite the process.

Based on these results, I think that we can justify these changes.

Training details:

  • Model: nvidia/segformer-b0-finetuned-cityscapes-1024-1024
  • Dataset: Antreas/Cityscapes

System details:

  • Windows 10 Pro
  • GPU = NVidia Quadro T2000
  • CPU = Intel Xeon 2.80GHz (12 CPUs)
  • RAM = 32GB

In the following chart, we can see that the standard code path goes OOM and fails as it tries to store all of the dense logits in memory on either the GPU or the CPU (as in the case of using eval_accumulation_steps). Only the run using batch_eval_metrics survives the evaluation cycle without going OOM and failing.

Further, we can see a great boost in memory efficiency during the train and eval phases when using the cache clearing; it leads to very slightly lower iteration speed, but leaves a lot of headroom for using a larger batch size during evaluation to compensate.

[Four W&B charts (5/5/2024) comparing the runs described above]

@LysandreJik (Member)

Awesome, great work @FoamoftheSea!

Approving the PR, @muellerzr feel free to merge at your convenience.

@muellerzr (Contributor) left a comment

Thanks for all your hard work with this!

@muellerzr merged commit df475bf into huggingface:main on May 6, 2024. 20 checks passed.
@Reveyer commented May 29, 2024

Hello, I've noticed that this pull request seems to slow down the trainer, likely due to the frequent use of torch.cuda.empty_cache(). Is there a way to optimize this, or could we have the option to choose whether or not to use torch.cuda.empty_cache()? @muellerzr

@muellerzr (Contributor) commented May 29, 2024

Do you have a small benchmark for us, @Reveyer? I haven't noticed this yet when I was investigating another timing issue, but I would be happy to look into this.

@Reveyer commented May 29, 2024

@muellerzr I was able to reproduce this on the official code. Here’s the command I used:

CUDA_VISIBLE_DEVICES=0 python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path facebook/bart-large \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=16 \
    --per_device_eval_batch_size=16 \
    --overwrite_output_dir \
    --predict_with_generate \
    --num_train_epochs="10" \
    --seed="42"

Here are the results with and without torch.cuda.empty_cache() removed from the trainer:

  • With torch.cuda.empty_cache() removed:
    | 30/179450 [00:22<37:48:51, 1.32it/s]
  • With torch.cuda.empty_cache():
    | 30/179450 [00:23<39:22:17, 1.27it/s]

The discrepancy is even more significant in my personal code, which utilizes Llama-3:

  • With torch.cuda.empty_cache() removed:
    | 5/89543 [00:43<216:51:58, 8.72s/it]
  • With torch.cuda.empty_cache():
    | 5/89543 [00:51<251:54:05, 10.13s/it]

I hope this helps clarify the issue. Looking forward to hearing your thoughts!

@FoamoftheSea (Contributor, Author)

@muellerzr @Reveyer I would test the speed difference on the second training iteration (after the first eval round). This change sacrificed some speed on the first training iteration for a noticeable increase in speed in the second and onward. Something about the eval round clogged up memory and made all subsequent training loops very slow; the cache emptying fixed that at a slight decrease in initial speed.

@Reveyer commented May 30, 2024

@FoamoftheSea Hello, I've identified that these two lines of code are key to the speed differences observed. Could you please test the impact of these two lines on your experiments? Thank you!

del inputs
torch.cuda.empty_cache()
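
One possible mitigation would be to make the cache clearing opt-in or periodic rather than per-step; this is only a hypothetical sketch of a change to that loop body, and the empty_cache_steps knob below is not part of this PR:

# Hypothetical: amortize the allocator flush by clearing only every `empty_cache_steps` batches
empty_cache_steps = 100  # e.g. 0/None would disable clearing entirely

del inputs
if empty_cache_steps and (step + 1) % empty_cache_steps == 0:
    torch.cuda.empty_cache()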

@muellerzr (Contributor)

Looks like we also have a CPU leak in here when using custom evals, so I'll be investigating this today.

@FoamoftheSea (Contributor, Author)

@Reveyer sorry for the late response, the past few weeks have been very busy. I will find some time this week to re-run the experiment with that change.

@FoamoftheSea (Contributor, Author)

I ran my test again on the main branch today. The results I got with/without the cache clearing in the training loop are mostly identical, other than that the run with no cache clearing was slightly less memory efficient, and in my case very slightly slower over time, but I think that might just be because my GPU was already warmed up for that training run.

Seeing how the difference here is minimal, I would have no problem with reverting the cache-clearing line in the training loop, since the training loop doesn't seem to cause the memory overflow problem that the eval loop does without it.

[Two W&B charts (6/19/2024) comparing the runs with and without cache clearing in the training loop]

Here's the script I ran the experiment with:

from collections import Counter
from typing import Mapping, Optional, Dict, Set

import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms.functional import pil_to_tensor

from transformers import (
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
    Trainer,
    TrainingArguments,
    EvalPrediction,
)

from cityscapesscripts.helpers.labels import id2label

id2trainId = {k: v.trainId for k, v in id2label.items() if k >= 0}
trainId2label = {v.trainId: v.name for k, v in id2label.items() if k >= 0}
conversion_lookup = np.array([id2trainId[i] for i in range(len(id2trainId))])

dataset = load_dataset("Antreas/Cityscapes")
train_dataset, eval_dataset = dataset["train"], dataset["val"]

image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")

BATCH_EVAL_METRICS = True


class SegformerSemanticSegEvalMetric:
    def __init__(
            self,
            id2label: Dict[int, str],
            ignore_class_ids: Optional[Set[int]] = None,
            reduced_labels: bool = False,
            batch_eval_metrics: bool = True,
    ):
        self.total_area_intersect = Counter()
        self.total_area_union = Counter()
        self.total_label_area = Counter()
        self.ignore_class_ids = ignore_class_ids or set()
        self.reduced_labels = reduced_labels
        self.id2label = id2label
        self.batch_eval_metrics = batch_eval_metrics

    def update(self, logits: torch.FloatTensor, gt_labels: torch.LongTensor):

        # logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = torch.nn.functional.interpolate(
            logits,
            size=gt_labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        gt_labels = gt_labels.detach().cpu().numpy()

        for class_id in self.id2label.keys():
            if class_id in self.ignore_class_ids:
                continue
            if self.reduced_labels:
                label_id = class_id - 1 if class_id != 0 else 255
            else:
                label_id = class_id
            pred_pixels = pred_labels == label_id
            gt_pixels = gt_labels == label_id
            class_label = self.id2label[class_id]
            self.total_area_intersect.update({class_label: np.sum(np.bitwise_and(pred_pixels, gt_pixels))})
            self.total_area_union.update({class_label: np.sum(np.bitwise_or(pred_pixels, gt_pixels))})
            self.total_label_area.update({class_label: np.sum(gt_pixels)})

    def compute(self):
        accuracies = {f"accuracy_{k}": self.total_area_intersect[k] / self.total_label_area[k] for k in self.total_area_union}
        ious = {f"iou_{k}": self.total_area_intersect[k] / self.total_area_union[k] for k in self.total_area_union}
        metrics = {
            "overall_accuracy": sum(self.total_area_intersect.values()) / sum(self.total_label_area.values()),
            "mean_accuracy": np.mean(list(accuracies.values())),
            "mean_iou": np.mean(list(ious.values())),
        }
        metrics.update(accuracies)
        metrics.update(ious)

        return metrics

    def __call__(self, eval_pred: EvalPrediction, compute_result=False):
        if self.batch_eval_metrics:
            return self._call_batched(eval_pred, compute_result)
        else:
            return self._call_nonbatched(eval_pred)

    def _call_nonbatched(self, eval_pred):
        mious = {}
        with torch.no_grad():
            logits, gt_labels = eval_pred.predictions, eval_pred.label_ids
            logits_tensor = torch.from_numpy(logits)
            # scale the logits to the size of the label
            logits_tensor = torch.nn.functional.interpolate(
                logits_tensor,
                size=gt_labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)

            pred_labels = logits_tensor.detach().cpu().numpy()

            for class_id in self.id2label.keys():
                if class_id in self.ignore_class_ids:
                    continue
                if self.reduced_labels:
                    label_id = class_id - 1 if class_id != 0 else 255
                else:
                    label_id = class_id
                pred_pixels = pred_labels == label_id
                gt_pixels = gt_labels == label_id
                class_label = self.id2label[class_id]
                intersection = np.sum(np.bitwise_and(pred_pixels, gt_pixels))
                union = np.sum(np.bitwise_or(pred_pixels, gt_pixels))
                mious[class_label] = intersection / union

        return np.mean(list(mious.values()))

    def _call_batched(self, eval_pred: EvalPrediction, compute_result: bool = True) -> Optional[dict]:
        with torch.no_grad():
            self.update(eval_pred.predictions, eval_pred.label_ids)
            return self.compute() if compute_result else None


def collate_fn(features: list):

    if not isinstance(features[0], Mapping):
        features = [vars(f) for f in features]
    first = features[0]
    images = None
    semantic_masks = None

    if "semantic_segmentation" in first and first["semantic_segmentation"] is not None:
        semantic_masks = [pil_to_tensor(f["semantic_segmentation"])[0] for f in features]
    if "image" in first and first["image"] is not None:
        images = [pil_to_tensor(f["image"].convert("RGB")) for f in features]

    processed = image_processor(images=images, segmentation_maps=semantic_masks)
    labels = np.array(processed.data["labels"])
    labels = conversion_lookup[labels]

    batch = {
        "pixel_values": torch.Tensor(np.array(processed.data["pixel_values"])),
        "labels": torch.LongTensor(np.array(labels)),
    }

    return batch


training_args = TrainingArguments(
    "segformer-test",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=300,
    eval_steps=150,
    max_steps=10000,
    logging_steps=1,
    load_best_model_at_end=True,
    push_to_hub=False,
    gradient_checkpointing=False,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    use_cpu=False,
    learning_rate=0.0002,
    batch_eval_metrics=BATCH_EVAL_METRICS,
    remove_unused_columns=False,
    # eval_accumulation_steps=1 if not BATCH_EVAL_METRICS else None,
    run_name="segformer-b0-cityscapes-t4e8-batch-eval-metrics-no-cache-clear-train-loop",
)

trainer = Trainer(
    model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=SegformerSemanticSegEvalMetric(id2label=trainId2label, ignore_class_ids={255}),
    data_collator=collate_fn,
)

trainer.train()

@SunMarc (Member) commented Jun 20, 2024

Hi @FoamoftheSea, thanks for re-running the experiments! From your observations, I think it will be better to remove the cache clearing in the training loop, as other users report a huge increase in training time. Would you like to open a PR? Otherwise, I'll do it!

@manoja328

Any solution to this issue? My trainer runs fine in train mode, but the same run crashes in eval mode.
