# load_utils.py
# Provenance: copied from a repository forked from Tiiiger/templm
# (original file: 60 lines, 55 loc, 2.12 KB).
from transformers import AutoTokenizer, AutoConfig, BartForConditionalGeneration
from template_search_bart import TemplateSearchBART
def _hub_kwargs(model_args):
    """Return the keyword arguments shared by every ``from_pretrained`` call."""
    return {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        # Forward ``True`` only when authentication was requested; otherwise
        # pass ``None`` so the library falls back to its own default.
        "use_auth_token": True if model_args.use_auth_token else None,
    }


def _load_tokenizer(model_args, add_prefix_space):
    """Load one tokenizer, preferring ``tokenizer_name`` over the model path."""
    return AutoTokenizer.from_pretrained(
        model_args.tokenizer_name or model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        add_prefix_space=add_prefix_space,
        **_hub_kwargs(model_args),
    )


def load_model_and_tokenizer(
    model_args,
    load_template_model=True,
    load_tokenizer=True,
    load_no_space_tokenizer=True,
):
    """Load the BART model and, optionally, its two tokenizer variants.

    Args:
        model_args: Object exposing the attributes read here:
            ``model_name_or_path``, ``tokenizer_name``, ``config_name``,
            ``cache_dir``, ``model_revision``, ``use_auth_token`` and
            ``use_fast_tokenizer``.
        load_template_model: When true the model is instantiated as
            ``TemplateSearchBART``; otherwise as
            ``BartForConditionalGeneration``.
        load_tokenizer: When true, load a tokenizer created with
            ``add_prefix_space=True``.
        load_no_space_tokenizer: When true, load a second tokenizer created
            with ``add_prefix_space=False``.

    Returns:
        A ``(model, tokenizer, no_space_tokenizer)`` tuple. A tokenizer slot
        is ``None`` when its corresponding ``load_*`` flag is false.
    """
    tokenizer = (
        _load_tokenizer(model_args, add_prefix_space=True)
        if load_tokenizer
        else None
    )
    no_space_tokenizer = (
        _load_tokenizer(model_args, add_prefix_space=False)
        if load_no_space_tokenizer
        else None
    )

    # The config is always loaded, independent of the tokenizer flags, and
    # falls back to the model path when no explicit config name is given.
    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        **_hub_kwargs(model_args),
    )

    model_class = (
        TemplateSearchBART if load_template_model else BartForConditionalGeneration
    )
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        # A path containing ".ckpt" is treated as a TensorFlow checkpoint.
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        **_hub_kwargs(model_args),
    )
    return model, tokenizer, no_space_tokenizer