Trainer - add cache clearing and the option for batched eval metrics computation #28769
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @pacman100 and @muellerzr
Hey everyone, I tried to look at the logs for the failed tests, but I don't see any actionable error reports. Can anyone help me figure out what needs to be done for them to pass?
The main CI is a bit broken at the moment.
Just re-ran the CI; you should rebase to main and it should be alright.
BTW @SunMarc, it would be nice if you can have a look as well!
CIs are green after merging main ✔️
Hi @FoamoftheSea, thanks for contributing! I left a few comments to better understand how you are performing the batched metric computation. Can you also add tests to check that we get the same result with/without batched computation?
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
        compute_result=is_last_step,
    )
else:
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host),
        compute_result=is_last_step,
    )
You can't add a compute_result argument here.
This code path would only be used if the user set args.batch_eval_metrics to True, so only those trying to use this feature would need to worry about the expectation for this argument. I'm definitely open to other suggestions though.
Once we settle on the right solution I will write a test for it 🙏
I will let @ArthurZucker and @muellerzr comment on that!
This seems okay to me if this path is only ever taken when a user has this enabled. (We should maybe write a snippet/clarification about it in the docstring of TrainingArguments.)
Sounds good! I'll work on this and an appropriate test for it some time this week
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
        compute_result=is_last_step,
    )
else:
    metrics = self.compute_metrics(
        EvalPrediction(predictions=preds_host, label_ids=labels_host),
        compute_result=is_last_step,
    )
Are you computing + storing the batched metric inside self.compute_metrics (it needs to be a class with __call__ defined)?
I'm actually using a globally instanced metrics class that maintains state with update and compute methods inside of the compute_metrics function that is passed to the trainer. There is probably a better solution, but I found this was the least intrusive on the current expected behavior of the trainer code: if the user does not activate the batched eval metrics option, they never hit the code path where the compute_result argument is used, and therefore don't have to change anything about their previous compute_metrics functions.
Here's a very basic pseudo-code example of how I use this:
from typing import Optional

import numpy as np
from transformers import Trainer

class MSEMetric:
    def __init__(self):
        self.batch_mse = []

    def update(self, preds, target):
        diff = target - preds
        batch_mse = np.mean(np.power(diff, 2))
        self.batch_mse.append(batch_mse)

    def compute(self):
        # Get result across entire eval set
        result = {"mse": np.mean(self.batch_mse)}
        # Reset batch statistics
        self.batch_mse = []
        return result

mse_metric = MSEMetric()

def compute_metrics(eval_pred, compute_result: bool = True) -> Optional[dict]:
    mse_metric.update(eval_pred.predictions, eval_pred.label_ids)
    if compute_result:
        return mse_metric.compute()

# Use this compute_metrics fn in trainer
trainer = Trainer(compute_metrics=compute_metrics, ...)
This mirrors the update and compute methodology from the metrics classes in torcheval.metrics, which is where I got the inspiration from (see https://pytorch.org/torcheval/main/generated/torcheval.metrics.MeanSquaredError.html)
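For reference, the torcheval pattern being mirrored looks roughly like this (a minimal sketch; the random tensors are just placeholders for per-batch predictions and targets):

import torch
from torcheval.metrics import MeanSquaredError

metric = MeanSquaredError()
for _ in range(3):  # stand-in for looping over eval batches
    preds, target = torch.rand(8), torch.rand(8)
    metric.update(preds, target)  # accumulate per-batch statistics
print(metric.compute())  # aggregate result over everything seen so far
metric.reset()  # clear state before the next evaluation round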
Makes sense! Thanks for explaining.
src/transformers/trainer.py (outdated)
if self.args.batch_eval_metrics:
    if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
        is_last_step = step == len(dataloader) - 1
        if args.include_inputs_for_metrics:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
                compute_result=is_last_step,
            )
        else:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds_host, label_ids=labels_host),
                compute_result=is_last_step,
            )

    del losses_host, preds_host, inputs_host, labels_host
    torch.cuda.empty_cache()
    losses_host, preds_host, inputs_host, labels_host = None, None, None, None
same comment as above
src/transformers/trainer.py (outdated)
@@ -3449,8 +3453,25 @@ def evaluation_loop(

    self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

    if self.args.batch_eval_metrics:
        if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
            is_last_step = step == len(dataloader) - 1
Quick note: not every dataset has __len__() defined, such as IterableDataset.
I see, we'll need to find a more robust way to identify the last step of the eval set then... If anyone has an idea let me know
You should be able to use self.accelerator.gradient_state.end_of_dataloader here :)
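A minimal standalone sketch of that suggestion, outside the Trainer context; TokenStream is a made-up stand-in for any dataset without __len__:

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, IterableDataset

class TokenStream(IterableDataset):
    # Hypothetical iterable dataset: it has no __len__, so len(dataloader) would fail.
    def __iter__(self):
        yield from (torch.tensor(i) for i in range(10))

accelerator = Accelerator()
dataloader = accelerator.prepare(DataLoader(TokenStream(), batch_size=4))

for step, batch in enumerate(dataloader):
    # Accelerate flips this flag while the final batch is being yielded, so the
    # last eval step can be detected without knowing the dataset length.
    is_last_step = accelerator.gradient_state.end_of_dataloader
    print(step, is_last_step)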
Adding these changes :)
Thanks! Overall this seems quite handy. If we can confirm that it does reduce your memory footprint without issue then I believe that's quite alright. Replied to comments from Marc's review. Let's apply those then I can give a checkmark on my end at least!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Commenting to keep the PR fresh. I got super busy the past couple weeks but I will finish this soon.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I will check this today, thank you!
I am looking into this, it looks like the conflict is due to a different management of the variables in latest, which actually discards the use of the intermediate variable I was clearing from memory, so I want to double check on how that affects this change. Also, I'm working on getting a test made for the batching functionality. Should have this ready soon.
# Conflicts:
#   src/transformers/trainer.py
#   src/transformers/training_args.py
The pytests are ready and I've updated the code to work with the latest Trainer updates. I still need to run an AB test to see if the cache clearing is still providing any benefit with the new changes from main, since they seem to have the potential to handle the same issue. I just need to finish setting up a test case. Let's hold off on merging until we have some fresh test results.
The test results demonstrate 2 things:
Based on these results, I think that we can justify these changes. Training details:
System details:
In the following chart, we can see that the standard code path goes OOM and fails as it tries to store all of the dense logits in memory on either the GPU or the CPU (as in the case of using eval_accumulation_steps). Further, we can see there is a great boost in memory efficiency during the train and eval phases when using the cache clearing, which leads to very slightly lower iteration speed but leaves a lot of headroom for using a larger batch size during evaluation to compensate.
Awesome, great work @FoamoftheSea! Approving the PR, @muellerzr feel free to merge at your convenience.
Thanks for all your hard work with this!
Hello, I've noticed that this pull request seems to slow down the trainer, likely due to the frequent use of torch.cuda.empty_cache().
Do you have a small benchmark for us @Reveyer? I haven't noticed this yet when I was investigating another timing issue, but I would be happy to look into this.
@muellerzr I was able to reproduce this on the official code. Here’s the command I used:
Here are the results with and without the cache clearing:
The discrepancy is even more significant in my personal code, which utilizes Llama-3:
I hope this helps clarify the issue. Looking forward to hearing your thoughts!
@muellerzr @Reveyer I would test the speed difference on the second training iteration (after the first eval round). This change sacrificed some speed on the first training iteration for a noticeable increase in speed in the second and onward after that. Something about the eval round clogged up memory and made all subsequent training loops very slow; the cache emptying fixed that at a slight decrease in initial speed.
@FoamoftheSea Hello, I've identified that these two lines of code are key to the speed differences observed. Could you please test the impact of these two lines on your experiments? Thank you!
transformers/src/transformers/trainer.py, lines 3270 to 3271 in 2b9e252
Looks like we also have a CPU leak in here when using custom evals, so I'll be investigating this today.
@Reveyer sorry for the late response, the past few weeks have been very busy. I will find some time this week to re-run the experiment with that change.
I ran my test again on the main branch today. The results I got with/without the cache clearing in the training loop are mostly identical, other than that the run with no cache clearing was slightly less memory efficient and, in my case, very slightly slower over time, but I think that might just be because my GPU was already warmed up for that training run. Seeing how the difference here is minimal, I would have no problem with reverting the line in the training loop for cache clearing, since it doesn't seem to cause the memory overflow problem that the eval loop does without it. Here's the script I ran the experiment with:
from collections import Counter
import torch
import numpy as np
from typing import Mapping, Optional, Dict, Set
from datasets import load_dataset
from torchvision.transforms.functional import pil_to_tensor
from transformers import (
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from cityscapesscripts.helpers.labels import id2label

id2trainId = {k: v.trainId for k, v in id2label.items() if k >= 0}
trainId2label = {v.trainId: v.name for k, v in id2label.items() if k >= 0}
conversion_lookup = np.array([id2trainId[i] for i in range(len(id2trainId))])

dataset = load_dataset("Antreas/Cityscapes")
train_dataset, eval_dataset = dataset["train"], dataset["val"]

image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")

BATCH_EVAL_METRICS = True


class SegformerSemanticSegEvalMetric:
    def __init__(
        self,
        id2label: Dict[int, str],
        ignore_class_ids: Optional[Set[int]] = None,
        reduced_labels: bool = False,
        batch_eval_metrics: bool = True,
    ):
        self.total_area_intersect = Counter()
        self.total_area_union = Counter()
        self.total_label_area = Counter()
        self.ignore_class_ids = ignore_class_ids or set()
        self.reduced_labels = reduced_labels
        self.id2label = id2label
        self.batch_eval_metrics = batch_eval_metrics

    def update(self, logits: torch.FloatTensor, gt_labels: torch.LongTensor):
        # logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = torch.nn.functional.interpolate(
            logits,
            size=gt_labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        gt_labels = gt_labels.detach().cpu().numpy()

        for class_id in self.id2label.keys():
            if class_id in self.ignore_class_ids:
                continue
            if self.reduced_labels:
                label_id = class_id - 1 if class_id != 0 else 255
            else:
                label_id = class_id
            pred_pixels = pred_labels == label_id
            gt_pixels = gt_labels == label_id
            class_label = self.id2label[class_id]
            self.total_area_intersect.update({class_label: np.sum(np.bitwise_and(pred_pixels, gt_pixels))})
            self.total_area_union.update({class_label: np.sum(np.bitwise_or(pred_pixels, gt_pixels))})
            self.total_label_area.update({class_label: np.sum(gt_pixels)})

    def compute(self):
        accuracies = {f"accuracy_{k}": self.total_area_intersect[k] / self.total_label_area[k] for k in self.total_area_union}
        ious = {f"iou_{k}": self.total_area_intersect[k] / self.total_area_union[k] for k in self.total_area_union}
        metrics = {
            "overall_accuracy": sum(self.total_area_intersect.values()) / sum(self.total_label_area.values()),
            "mean_accuracy": np.mean(list(accuracies.values())),
            "mean_iou": np.mean(list(ious.values())),
        }
        metrics.update(accuracies)
        metrics.update(ious)
        return metrics

    def __call__(self, eval_pred: EvalPrediction, compute_result=False):
        if self.batch_eval_metrics:
            return self._call_batched(eval_pred, compute_result)
        else:
            return self._call_nonbatched(eval_pred)

    def _call_nonbatched(self, eval_pred):
        mious = {}
        with torch.no_grad():
            logits, gt_labels = eval_pred.predictions, eval_pred.label_ids
            logits_tensor = torch.from_numpy(logits)
            # scale the logits to the size of the label
            logits_tensor = torch.nn.functional.interpolate(
                logits_tensor,
                size=gt_labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)

            pred_labels = logits_tensor.detach().cpu().numpy()
            for class_id in self.id2label.keys():
                if class_id in self.ignore_class_ids:
                    continue
                if self.reduced_labels:
                    label_id = class_id - 1 if class_id != 0 else 255
                else:
                    label_id = class_id
                pred_pixels = pred_labels == label_id
                gt_pixels = gt_labels == label_id
                class_label = self.id2label[class_id]
                intersection = np.sum(np.bitwise_and(pred_pixels, gt_pixels))
                union = np.sum(np.bitwise_or(pred_pixels, gt_pixels))
                mious[class_label] = intersection / union

        return np.mean(list(mious.values()))

    def _call_batched(self, eval_pred: EvalPrediction, compute_result: bool = True) -> Optional[dict]:
        with torch.no_grad():
            self.update(eval_pred.predictions, eval_pred.label_ids)
            return self.compute() if compute_result else None


def collate_fn(features: list):
    if not isinstance(features[0], Mapping):
        features = [vars(f) for f in features]
    first = features[0]

    images = None
    semantic_masks = None
    if "semantic_segmentation" in first and first["semantic_segmentation"] is not None:
        semantic_masks = [pil_to_tensor(f["semantic_segmentation"])[0] for f in features]
    if "image" in first and first["image"] is not None:
        images = [pil_to_tensor(f["image"].convert("RGB")) for f in features]

    processed = image_processor(images=images, segmentation_maps=semantic_masks)
    labels = np.array(processed.data["labels"])
    labels = conversion_lookup[labels]

    batch = {
        "pixel_values": torch.Tensor(np.array(processed.data["pixel_values"])),
        "labels": torch.LongTensor(np.array(labels)),
    }

    return batch


training_args = TrainingArguments(
    "segformer-test",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=300,
    eval_steps=150,
    max_steps=10000,
    logging_steps=1,
    load_best_model_at_end=True,
    push_to_hub=False,
    gradient_checkpointing=False,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    use_cpu=False,
    learning_rate=0.0002,
    batch_eval_metrics=BATCH_EVAL_METRICS,
    remove_unused_columns=False,
    # eval_accumulation_steps=1 if not BATCH_EVAL_METRICS else None,
    run_name="segformer-b0-cityscapes-t4e8-batch-eval-metrics-no-cache-clear-train-loop",
)

trainer = Trainer(
    model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=SegformerSemanticSegEvalMetric(id2label=trainId2label, ignore_class_ids={255}),
    data_collator=collate_fn,
)

trainer.train()
Hi @FoamoftheSea, thanks for re-running the experiments! From your observations, I think it will be better to remove the cache clearing in the training loop, as other users report a large increase in training time. Would you like to open a PR? Otherwise, I'll do it!
Any solution to this issue? My trainer runs fine in train mode, but it crashes when run in eval mode.
What does this PR do?
This PR does two things which are necessary for using the Trainer in resource-constrained environments (like my RTX-3070Ti machine):
- Adds cache clearing to the train and eval loops to reduce the memory footprint.
- Adds the option batch_eval_metrics for batched eval metrics computation. Users can set batch_eval_metrics to True and construct a compute_metrics function which can update average metrics at a batch level to prevent OOM errors with large eval sets. Particularly useful for vision transformers (see the usage sketch below).
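A rough usage sketch of the new option (illustrative only: RunningAccuracy, model, and eval_dataset are placeholders, and with batch_eval_metrics enabled the EvalPrediction fields arrive per batch rather than concatenated over the whole eval set):

from typing import Optional

import torch
from transformers import EvalPrediction, Trainer, TrainingArguments

class RunningAccuracy:
    # Hypothetical stateful metric that accumulates counts batch by batch.
    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, eval_pred: EvalPrediction, compute_result: bool = False) -> Optional[dict]:
        preds = torch.as_tensor(eval_pred.predictions).argmax(dim=-1)
        labels = torch.as_tensor(eval_pred.label_ids)
        self.correct += int((preds == labels).sum())
        self.total += int(labels.numel())
        if compute_result:  # only True on the last eval batch
            result = {"accuracy": self.correct / self.total}
            self.correct, self.total = 0, 0  # reset for the next evaluation round
            return result

training_args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=8,
    batch_eval_metrics=True,  # enables the per-batch compute_metrics calls
)
trainer = Trainer(
    model=model,  # placeholder model
    args=training_args,
    eval_dataset=eval_dataset,  # placeholder dataset
    compute_metrics=RunningAccuracy(),
)
trainer.evaluate()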
@muellerzr