-
Notifications
You must be signed in to change notification settings - Fork 3
/
interactive.py
97 lines (78 loc) · 3.91 KB
/
interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import code
import json
import logging
import os
import prettytable
import time
import torch
import torch.nn as nn
from torch.serialization import default_restore_location
from tqdm import tqdm
from reader import models, utils
from reader.data.dictionary import Dictionary
from reader.data.dataset import ReadingDataset, BatchSampler
from reader.data.tokenizer import Tokenizer, SpacyTokenizer, CoreNLPTokenizer
def get_args():
    """Build and parse the command-line arguments for the interactive QA console.

    Returns:
        argparse.Namespace with `seed`, `checkpoint` (required path) and
        `tokenizer` (one of 'spacy' / 'corenlp', default 'corenlp').
    """
    p = argparse.ArgumentParser('Question Answering - Interactive Console')
    p.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed')
    p.add_argument('--checkpoint', required=True, help='checkpoint path')
    p.add_argument('--tokenizer', default='corenlp', choices=['spacy', 'corenlp'], help='tokenizer')
    args = p.parse_args()
    return args
def main(args):
    """Load a trained reading-comprehension model from a checkpoint and open an
    interactive console where `answer(context, question, topk)` runs extractive QA.

    Args:
        args: parsed CLI namespace from get_args(); training-time arguments are
            restored from the checkpoint, with CLI flags taking precedence.
    """
    torch.manual_seed(args.seed)

    # Restore the arguments the model was trained with; CLI values override
    # them, and embed_path is cleared so no pretrained embeddings are loaded
    # (the checkpoint weights already contain them).
    saved = torch.load(args.checkpoint, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    merged = {**vars(saved['args']), **vars(args)}
    merged['embed_path'] = None
    args = argparse.Namespace(**merged)
    utils.init_logging(args)

    # Word- and character-level dictionaries produced during preprocessing.
    word_dict = Dictionary.load(os.path.join(args.data, 'dict.txt'))
    logging.info('Loaded a dictionary with {} words'.format(len(word_dict)))
    char_dict = Dictionary.load(os.path.join(args.data, 'char_dict.txt'))
    logging.info('Loaded a character dictionary with {} words'.format(len(char_dict)))

    # Rebuild the model architecture, move it to the GPU in inference mode,
    # then restore the trained weights.
    model = models.build_model(args, word_dict, char_dict).cuda().eval()
    model.load_state_dict(saved['model'])

    with open(os.path.join(args.data, 'feature_dict.json')) as fp:
        feature_dict = json.load(fp)

    # Both tokenizers annotate with lemma/POS/NER, matching training features.
    annotators = ['lemma', 'pos', 'ner']
    if args.tokenizer == 'corenlp':
        tokenizer = CoreNLPTokenizer(annotators=annotators)
    elif args.tokenizer == 'spacy':
        tokenizer = SpacyTokenizer(annotators=annotators)

    def answer(context, question, topk=1):
        """Tokenize a raw (context, question) pair, score all spans with the
        model, and print the top-k predicted answer spans with their scores."""
        tic = time.time()

        # Tokenize raw text into the annotated form the dataset expects.
        context = tokenizer.tokenize(context)
        question = tokenizer.tokenize(question)

        # A single unanswered example wrapped in the training-time dataset /
        # loader machinery so featurization and batching match training.
        examples = [{'id': 0, 'question': question, 'context_id': 0, 'answers': {'spans': [], 'texts': []}}]
        dataset = ReadingDataset(
            args, [context], examples, word_dict, char_dict, feature_dict=feature_dict, skip_no_answer=False
        )
        loader = torch.utils.data.DataLoader(
            dataset, num_workers=args.num_workers, collate_fn=dataset.collater,
            batch_sampler=BatchSampler(dataset, args.max_tokens, args.batch_size, shuffle=False, seed=args.seed)
        )

        # Single forward pass over the one-example batch; no gradients needed.
        with torch.no_grad():
            batch = utils.move_to_cuda(next(iter(loader)))
            start_scores, end_scores = model(
                batch['context_tokens'], batch['question_tokens'],
                context_chars=batch['context_chars'],
                question_chars=batch['question_chars'],
                context_features=batch['context_features']
            )
            start_preds, end_preds, scores = model.decode(start_scores, end_scores, topk=topk)

        # Map token-level span predictions back to character offsets in the
        # original text and render them as a ranked table.
        results = prettytable.PrettyTable(['Rank', 'Span', 'Score'])
        for rank, (start_pred, end_pred, score) in enumerate(zip(start_preds[0], end_preds[0], scores[0]), start=1):
            span_begin = context['offsets'][start_pred][0]
            span_end = context['offsets'][end_pred][1]
            results.add_row([rank, context['text'][span_begin: span_end], score.item()])
        print(results)
        print('Time: %.4f' % (time.time() - tic))

    # Drop into a REPL with `answer` (and everything else here) in scope.
    code.interact(banner='>>> Usage: answer(context, question, topk=5)', local=locals())
if __name__ == '__main__':
    # Script entry point: parse CLI flags and launch the interactive console.
    main(get_args())