A partial, empirical answer to my question on Cross Validated.
See `paper.pdf`.
Slides (please play the slideshow instead of scrolling through slides).
In researcher terms
There's a new, hot, few-shot NLP benchmark on the block. Alice submits her model to the leaderboard and gets SOTA accuracy.
In engineer terms
Andy: Hey team, I'm looking at the notebook for our new model by @Barbie, and I see:

```python
test_set_accuracy = (
    llm
    .pretrain(df_test["text"])
    .train(df_train["text"], df_train["label"])
    .evaluate(df_test["text"], df_test["label"])
)
```
Barbie: it should be fine bc i didnt do:

```python
llm.train(df_test["text"], df_test["label"])
```
Andy: Interesting. I'm not sure if it's ok to pretrain on unlabeled test set text like that. Could `test_set_accuracy` be higher than what we'll see in production?
Barbie: 🤔
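Andy's worry amounts to a comparison between two protocols. Here's a minimal sketch using a toy stand-in for the `llm` object above (the class, data, and placeholder accuracy are all hypothetical — not the repo's actual interface):

```python
# Illustrative stand-in for the chainable `llm` object above; the real
# models in this repo do not use this toy API.
class ToyLM:
    def __init__(self):
        self.pretraining_text = None

    def pretrain(self, texts):
        # Unsupervised step: sees raw text only, no labels.
        self.pretraining_text = list(texts)
        return self

    def train(self, texts, labels):
        # Supervised step: sees (text, label) pairs.
        return self

    def evaluate(self, texts, labels):
        # A real model would predict and score here; return a placeholder.
        return 0.0


df_train = {"text": ["a", "b"], "label": [0, 1]}
df_test = {"text": ["c", "d"], "label": [1, 0]}

# Barbie's protocol: pretraining sees the test set's unlabeled text.
acc_pretrain_on_test = (
    ToyLM()
    .pretrain(df_test["text"])
    .train(df_train["text"], df_train["label"])
    .evaluate(df_test["text"], df_test["label"])
)

# Control protocol: pretraining sees only non-test text.
acc_control = (
    ToyLM()
    .pretrain(df_train["text"])
    .train(df_train["text"], df_train["label"])
    .evaluate(df_test["text"], df_test["label"])
)
```

The experiment in this repo estimates whether the first protocol's accuracy is systematically higher than the second's.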
- Clone the repo:

  ```shell
  git clone https://github.com/kddubey/pretrain-on-test.git
  ```

- `cd` into it:

  ```shell
  cd pretrain-on-test
  ```

- Install dependencies (in a virtual environment):

  ```shell
  python -m pip install .
  ```

Reproduce the experiment results by running `./experiment.sh` on a T4 GPU, which will take roughly 50 hours to finish. Default batch sizes are set to fit on a single T4 GPU; for some datasets, the batch sizes needed to be decreased.

To analyze the accuracy data, see `analysis/`.
Terminal (local)
```shell
python run.py --help
```

For a quick, CPU-friendly, local run:

```shell
./experiment_mini.sh
```
Notebook (local)
The terminal output is quite verbose. For minimal but sufficient info, run this in a notebook:

```python
from run import run, Experiment

experiment = Experiment(lm_type="bert", dataset_names=...)
run(experiment)
```
For a quick, CPU-friendly, local run:

```python
from run import run, Experiment

experiment = Experiment(
    lm_type="bert-tiny",
    dataset_names=["ag_news", "SetFit/amazon_counterfactual_en"],
    num_subsamples=1,
    num_train=10,
    num_test=10,
    num_train_epochs_classification=1,
    num_train_epochs_pretrain=1,
    per_device_train_batch_size_pretrain=4,
    per_device_train_batch_size_classification=4,
    per_device_eval_batch_size_classification=4,
)
run(experiment)
```
Google Cloud Platform
Other cloud providers
Other cloud providers are not yet supported, sorry.
To support them, implement logging and file-uploading functionality; see `cloud.py`. Then update `cloud_provider_to_create_data_handlers` in `run.py`.
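As a rough sketch of what that could look like (every name here except `cloud.py`'s registry key idea is hypothetical — check `cloud.py` and `run.py` for the actual interfaces):

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical shape of a provider's handlers; the real definitions
# live in cloud.py.
@dataclass
class DataHandlers:
    log: Callable[[str], None]               # wire to the provider's logging service
    upload_file: Callable[[str, str], None]  # local path -> bucket path


def create_aws_data_handlers() -> DataHandlers:
    # Sketch for AWS: in practice you'd wire these to boto3 / CloudWatch.
    def log(message: str) -> None:
        print(f"[aws] {message}")  # placeholder for real logging

    def upload_file(local_path: str, bucket_path: str) -> None:
        print(f"[aws] would upload {local_path} -> {bucket_path}")

    return DataHandlers(log=log, upload_file=upload_file)


# Then register the factory, mirroring the dict in run.py:
cloud_provider_to_create_data_handlers = {
    "aws": create_aws_data_handlers,
}
```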
You'll probably find `./cloud_scripts/_setup_python_env.sh` useful for cloud runs. Note that it assumes the bucket name is `pretrain-on-test-accuracies`, and that the GPU image you're using already has Python 3.10+, pip, and venv/conda on it.