
The model requires num_beams, although it is not needed in the example #105

Open
LEv145 opened this issue Feb 16, 2023 · 3 comments

@LEv145

LEv145 commented Feb 16, 2023

Ubuntu 20.04
pytorch==1.11.0a0+17540c5c
NVIDIA CUDA 11.6.0
TensorRT 8.2.3
transformers==4.26.1
apex NVIDIA/apex@0c8400a or (qywu/apex@798a36c with patch _amp_state.py)
deepspeed==0.8.0
triton==1.0.0
timm==0.3.2

Code:

import os


os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "5000"
os.environ["USE_DEEPSPEED"] = "1"


from src.xl_wrapper import RuGPT3XL



gpt = RuGPT3XL.from_pretrained(
    "sberbank-ai/rugpt3xl",
    weights_path="/mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt",
    seq_len=512,
)
gpt.generate(
    "Кто был президентом США в 2020?",  # "Who was the US president in 2020?"
    max_length=50,
    no_repeat_ngram_size=3,
    repetition_penalty=2.0,
)

Error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/rugpts/src/xl_wrapper.py", line 224, in generate
    res = super().generate(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 1331, in generate
    (generation_config.num_beams > 1)
TypeError: '>' not supported between instances of 'NoneType' and 'int'

I don't know what num_beams does or how to make it work, but I would be happy to help.
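The traceback shows that generation_config.num_beams is None when transformers evaluates num_beams > 1. In Python 3, an ordering comparison between None and an int raises exactly this TypeError. One plausible cause is that the wrapper forwards num_beams=None explicitly, which overrides the library default of 1. This is an assumption: we have not checked xl_wrapper.py.

The sketch below reproduces the failing comparison. It also shows a hypothetical pattern that forwards only the options the caller actually set. The names drop_unset and forward_generate_kwargs are illustrative and are not part of ru-gpts or transformers.

```python
from typing import Any, Dict, Optional


def reproduce_type_error() -> str:
    """Evaluate the comparison from the traceback with num_beams unset (None)."""
    num_beams = None
    try:
        _ = num_beams > 1  # ordering comparison with None raises TypeError
    except TypeError as exc:
        return str(exc)
    return "no error"


def drop_unset(**kwargs: Any) -> Dict[str, Any]:
    """Keep only the keyword arguments whose value is not None."""
    return {name: value for name, value in kwargs.items() if value is not None}


def forward_generate_kwargs(
    max_length: Optional[int] = None,
    num_beams: Optional[int] = None,
    do_sample: Optional[bool] = None,
    no_repeat_ngram_size: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical wrapper step: build the kwargs passed on to the library's
    generate(), omitting unset options so the library defaults still apply."""
    return drop_unset(
        max_length=max_length,
        num_beams=num_beams,
        do_sample=do_sample,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
    )


print(reproduce_type_error())
# prints: '>' not supported between instances of 'NoneType' and 'int'
print(forward_generate_kwargs(max_length=50, no_repeat_ngram_size=3, repetition_penalty=2.0))
# prints: {'max_length': 50, 'no_repeat_ngram_size': 3, 'repetition_penalty': 2.0}
```

Consider the arguments used in the original call from this issue. With this pattern, num_beams is simply absent from the forwarded arguments, so the library would fall back to its own default instead of comparing None with 1.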

Pip freeze
absl-py==1.0.0
alabaster==0.7.12
apex==0.1
appdirs==1.4.4
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1618968359944/work
astunparse==1.6.3
attrs==21.4.0
audioread==2.1.9
Babel==2.9.1
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1631087867185/work
black @ file:///home/conda/feedstock_root/build_artifacts/black-recipe_1643636307408/work
bleach==4.1.0
blis @ file:///home/conda/feedstock_root/build_artifacts/cython-blis_1636053204017/work
boto3==1.11.11
botocore==1.14.17
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1636012188166/work
cachetools==5.0.0
catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1638867392804/work
certifi==2021.10.8
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work
chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1635814844635/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1638815705608/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1635822600067/work
cloudpickle==2.0.0
codecov==2.1.12
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1602866480661/work
conda==4.11.0
conda-build==3.21.8
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1636021700973/work
coverage==6.3.1
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1639699280509/work
cudf @ file:///rapids/cudf-21.12.0a0%2B293.g0930f712e6-cp38-cp38-linux_x86_64.whl
cugraph @ file:///rapids/cugraph-21.12.0a0%2B95.g4b8c1330-cp38-cp38-linux_x86_64.whl
cuml @ file:///rapids/cuml-21.12.0a0%2B116.g4ce5bd609-cp38-cp38-linux_x86_64.whl
cupy-cuda115 @ file:///rapids/cupy_cuda115-9.6.0-cp38-cp38-manylinux1_x86_64.whl
cycler==0.11.0
cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1636053152744/work
Cython==0.29.27
dask @ file:///rapids/dask-2021.11.2-py3-none-any.whl
dask-cuda @ file:///rapids/dask_cuda-21.12.0-py3-none-any.whl
dask-cudf @ file:///rapids/dask_cudf-21.12.0a0%2B293.g0930f712e6-py3-none-any.whl
dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
debugpy==1.5.1
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
deepspeed==0.8.0
defusedxml==0.7.1
distributed @ file:///rapids/distributed-2021.11.2-py3-none-any.whl
docutils==0.15.2
entrypoints==0.3
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1633213722787/work
expecttest==0.1.3
fastrlock==0.8
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1641470428964/work
flake8==3.7.9
Flask==2.0.3
flatbuffers==23.1.21
fonttools==4.29.1
fsspec==2022.1.0
future==0.18.2
gast==0.4.0
glob2==0.7
google-auth==2.6.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
graphsurgeon @ file:///workspace/TensorRT-8.2.3.0/graphsurgeon/graphsurgeon-0.4.5-py2.py3-none-any.whl
grpcio==1.43.0
h5py==3.8.0
HeapDict==1.0.1
hjson==3.1.0
huggingface-hub==0.12.0
hypothesis==4.50.8
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1609836280497/work
imagesize==1.3.0
importlib-metadata==4.11.1
importlib-resources==5.4.0
iniconfig==1.1.1
ipykernel==6.9.0
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1642613634924/work
ipython-genutils==0.2.0
itsdangerous==2.0.1
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1637175084646/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1636510082894/work
jmespath==0.10.0
joblib==1.1.0
json5==0.9.6
jsonschema==4.4.0
jupyter-client==7.1.2
jupyter-core==4.9.1
jupyter-tensorboard @ git+https://github.com/cliffwoolley/jupyter_tensorboard.git@ffa7e26138b82549453306e06b535a9ac36db17a
jupyterlab==2.3.2
jupyterlab-pygments==0.1.2
jupyterlab-server==1.2.0
jupytext==1.13.7
keras==2.11.0
kiwisolver==1.3.2
langcodes @ file:///home/conda/feedstock_root/build_artifacts/langcodes_1636741340529/work
libarchive-c @ file:///home/conda/feedstock_root/build_artifacts/python-libarchive-c_1643045750800/work
libclang==15.0.6.1
librosa==0.9.0
llvmlite==0.36.0
lmdb==1.3.0
locket==0.2.1
Markdown==3.3.6
markdown-it-py==1.1.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1635833572614/work
matplotlib==3.5.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1631080358261/work
mccabe==0.6.1
mdit-py-plugins==0.3.0
mistune==0.8.4
mock @ file:///home/conda/feedstock_root/build_artifacts/mock_1635819534735/work
msgpack==1.0.3
murmurhash @ file:///home/conda/feedstock_root/build_artifacts/murmurhash_1636019583024/work
mypy-extensions @ file:///home/conda/feedstock_root/build_artifacts/mypy_extensions_1635839660470/work
nbclient==0.5.11
nbconvert==6.4.2
nbformat==5.1.3
nest-asyncio==1.5.4
networkx==2.6.3
ninja==1.11.1
nltk==3.7
notebook==6.4.1
numba @ file:///home/conda/feedstock_root/build_artifacts/numba_1623568544775/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1643958805350/work
nvidia-dali-cuda110==1.10.0
nvidia-pyindex==1.0.9
nvtx==0.2.4
oauthlib==3.2.0
onnx @ file:///opt/pytorch/pytorch/third_party/onnx
opt-einsum==3.3.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
pandas==1.3.5
pandocfilters==1.5.0
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
partd==1.2.0
pathspec @ file:///home/conda/feedstock_root/build_artifacts/pathspec_1626613672358/work
pathy @ file:///home/conda/feedstock_root/build_artifacts/pathy_1635227809952/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow @ file:///tmp/pillow-simd
pkginfo @ file:///home/conda/feedstock_root/build_artifacts/pkginfo_1638813452194/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1644222440849/work
pluggy==1.0.0
polygraphy==0.33.0
pooch==1.6.0
preshed @ file:///home/conda/feedstock_root/build_artifacts/preshed_1636077712344/work
prettytable==3.1.0
prometheus-client==0.13.1
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1643362612956/work
protobuf==3.19.4
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1640887117172/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py==1.11.0
py-cpuinfo==9.0.0
pyarrow @ file:///rapids/pyarrow-5.0.0-cp38-cp38-linux_x86_64.whl
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.9.1
pycocotools @ git+https://github.com/nvidia/cocoapi.git@142b17a358fdb5a31f9d5153d7a9f3f1cd385178#subdirectory=PythonAPI
pycodestyle==2.5.0
pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1636020377748/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1636021149719/work
pydot==1.4.2
pyflakes==2.1.1
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1641580240686/work
pynvml==11.4.1
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1633192417276/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1642753572664/work
pyrsistent==0.18.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1635862404924/work
pytest==6.2.5
pytest-cov==3.0.0
pytest-pythonpath==0.7.4
python-dateutil==2.8.2
python-hostlist==1.21
python-nvd3==0.15.0
python-slugify==5.0.2
pytorch-quantization==2.1.2
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1633452062248/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1636139793187/work
pyzmq==22.3.0
regex==2020.1.8
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1637771257551/work
requests-oauthlib==1.3.1
resampy==0.2.2
revtok @ git+git://github.com/jekbradbury/revtok.git@f1998b72a941d1e5f9578a66dc1c20b01913caab
rmm @ file:///rapids/rmm-21.12.0a0%2B31.g0acbd51-cp38-cp38-linux_x86_64.whl
rsa==4.8
ruamel-yaml-conda @ file:///home/conda/feedstock_root/build_artifacts/ruamel_yaml_1636009157217/work
s3transfer==0.3.7
sacremoses==0.0.47
scikit-learn @ file:///rapids/scikit_learn-0.24.0-cp38-cp38-manylinux2010_x86_64.whl
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1619561901336/work
Send2Trash==1.8.0
sentencepiece==0.1.97
shellingham @ file:///home/conda/feedstock_root/build_artifacts/shellingham_1612179560728/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smart-open @ file:///home/conda/feedstock_root/build_artifacts/smart_open_1630238320325/work
snowballstemmer==2.2.0
sortedcontainers==2.4.0
SoundFile==0.10.3.post1
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1638550740809/work
spacy @ file:///home/conda/feedstock_root/build_artifacts/spacy_1642167419405/work
spacy-legacy @ file:///home/conda/feedstock_root/build_artifacts/spacy-legacy_1625687473390/work
spacy-loggers @ file:///home/conda/feedstock_root/build_artifacts/spacy-loggers_1634809367310/work
Sphinx==4.4.0
sphinx-glpi-theme==0.3
sphinx-rtd-theme==1.0.0
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
srsly @ file:///home/conda/feedstock_root/build_artifacts/srsly_1638879568141/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1642255706390/work
tabulate==0.8.9
tblib==1.7.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.30.0
tensorrt @ file:///workspace/TensorRT-8.2.3.0/python/tensorrt-8.2.3.0-cp38-none-linux_x86_64.whl
termcolor==2.2.0
terminado==0.13.1
testpath==0.5.0
text-unidecode==1.3
thinc @ file:///home/conda/feedstock_root/build_artifacts/thinc_1638980259098/work
threadpoolctl==3.1.0
timm==0.3.2
tokenizers==0.13.2
toml==0.10.2
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
toolz==0.11.2
torch==1.11.0a0+17540c5
torch-tensorrt @ file:///opt/pytorch/torch_tensorrt/py/dist/torch_tensorrt-1.1.0a0-cp38-cp38-linux_x86_64.whl
torchtext @ file:///opt/pytorch/text
torchvision @ file:///opt/pytorch/vision
tornado==6.1
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1632160078689/work
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1635260543454/work
transformers==4.26.1
treelite @ file:///rapids/treelite-2.1.0-py3-none-manylinux2014_x86_64.whl
treelite-runtime @ file:///rapids/treelite_runtime-2.1.0-py3-none-manylinux2014_x86_64.whl
triton==1.0.0
typed-ast @ file:///home/conda/feedstock_root/build_artifacts/typed-ast_1643045767561/work
typer @ file:///home/conda/feedstock_root/build_artifacts/typer_1630326630489/work
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1638334978229/work
ucx-py @ file:///rapids/ucx_py-0.21.0a0%2B37.gbfa0450-cp38-cp38-linux_x86_64.whl
uff @ file:///workspace/TensorRT-8.2.3.0/uff/uff-0.6.9-py2.py3-none-any.whl
urllib3==1.25.11
wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1638865582891/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work
webencodings==0.5.1
Werkzeug==2.0.3
wget==3.2
wrapt==1.14.1
xgboost @ file:///rapids/xgboost-1.5.0-cp38-cp38-linux_x86_64.whl
zict==2.0.0
zipp==3.7.0
@TatianaShavrina
Collaborator

Hey @LEv145, thank you for bringing that up!

The num_beams parameter refers to the beam search decoding strategy for the model: see the HuggingFace explanation.
Try passing it to the generation function as an argument, or stick to sampling or greedy generation.

The parameters can be found in the generate function in the xl_wrapper script.
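Here is an illustration of what num_beams controls. Beam search keeps the num_beams highest-scoring partial sequences at every step. With num_beams=1 and sampling off, it reduces to greedy decoding.

The toy example below uses an invented next-token probability table. It is a simplified illustration, not the implementation in transformers or ru-gpts, which also handle EOS tokens, length limits, early_stopping and more.

```python
import math
from typing import Dict, List, Tuple

# Invented toy "model": next-token probabilities given the previous token.
NEXT: Dict[str, Dict[str, float]] = {
    "<s>": {"A": 0.5, "B": 0.4, "C": 0.1},
    "A": {"x": 0.4, "y": 0.35, "z": 0.25},
    "B": {"x": 0.9, "y": 0.05, "z": 0.05},
    "C": {"x": 0.34, "y": 0.33, "z": 0.33},
}


def beam_search(num_beams: int, steps: int) -> Tuple[List[str], float]:
    """Keep the num_beams best partial sequences (by summed log-probability)
    at each step; num_beams=1 is greedy decoding."""
    beams: List[Tuple[List[str], float]] = [(["<s>"], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            # Extend every kept sequence with every possible next token.
            for token, prob in NEXT[tokens[-1]].items():
                candidates.append((tokens + [token], score + math.log(prob)))
        # Prune back to the num_beams highest-scoring candidates.
        candidates.sort(key=lambda cand: cand[1], reverse=True)
        beams = candidates[:num_beams]
    best_tokens, best_score = beams[0]
    return best_tokens[1:], math.exp(best_score)  # drop "<s>", return probability


greedy_tokens, greedy_prob = beam_search(num_beams=1, steps=2)
beam_tokens, beam_prob = beam_search(num_beams=2, steps=2)
print(greedy_tokens, round(greedy_prob, 2))  # prints: ['A', 'x'] 0.2
print(beam_tokens, round(beam_prob, 2))      # prints: ['B', 'x'] 0.36
```

Greedy decoding commits to the locally best first token and ends with a lower-probability sequence than the one two beams find. In transformers, num_beams > 1 together with do_sample=False selects beam search. That is what the later call in this thread (num_beams=5, early_stopping=True) requests.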

@LEv145
Author

LEv145 commented Feb 21, 2023

Thanks, it works!
But now there is a problem when processing the result:

Load checkpoint from /mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt
Model Loaded
Traceback (most recent call last):
  File "/mnt/store/tests/test_rugpt3xl.py", line 29, in <module>
    main()
  File "/mnt/store/tests/test_rugpt3xl.py", line 19, in main
    result = gpt.generate(
  File "/opt/ru-gpts/src/xl_wrapper.py", line 244, in generate
    return list(map(self.tokenizer.decode, res.tolist()))
AttributeError: 'NoneType' object has no attribute 'tolist'
Code:
import os
import sys

sys.path.append("/opt/ru-gpts/")
os.environ["USE_DEEPSPEED"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "5000"
from src.xl_wrapper import RuGPT3XL


def main():
    gpt = RuGPT3XL.from_pretrained(
        "sberbank-ai/rugpt3xl",
        weights_path="/mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt",
        seq_len=512,
    )
    result = gpt.generate(
        "Кто был президентом США в 2020? ",  # "Who was the US president in 2020?"
        max_length=50,
        num_beams=5,
        early_stopping=True,
    )
    print(result)


if __name__ == "__main__":
    main()
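The second traceback fails one step later. The wrapper calls res.tolist() on whatever the underlying generate() returned, and that value is None. We have not traced why generate() returns None on this path, so the sketch below does not fix the cause.

Instead, it shows a hypothetical defensive version of the decoding step that fails with an explicit message. FakeTensor stands in for a torch tensor, and decode_generated is an illustrative name, not a function in xl_wrapper.py.

```python
from typing import Callable, List, Sequence


class FakeTensor:
    """Stand-in for a torch tensor of token ids; only provides .tolist()."""

    def __init__(self, rows: List[List[int]]) -> None:
        self._rows = rows

    def tolist(self) -> List[List[int]]:
        return self._rows


def decode_generated(res, decode: Callable[[Sequence[int]], str]) -> List[str]:
    """Decode each generated sequence, raising a clear error if generate()
    returned None instead of token ids."""
    if res is None:
        raise RuntimeError(
            "generate() returned None instead of token ids; "
            "check which decoding path in the wrapper produced it"
        )
    return [decode(ids) for ids in res.tolist()]


# Toy decoder for illustration: maps ids to characters.
vocab = {1: "a", 2: "b"}
toy_decode = lambda ids: "".join(vocab[i] for i in ids)
print(decode_generated(FakeTensor([[1, 2], [2]]), toy_decode))  # prints: ['ab', 'b']
```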

@sh0tcall3r

I have the same problem while generating text with the model.
First it requires num_beams, and once that is set, AttributeError: 'NoneType' object has no attribute 'tolist' appears, as in the post above.
Please fix this or explain how to resolve it.
