import argparse
import pickle
import sys
import urllib.request
from enum import Enum

import graph
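# graph.Graph (a companion module, not shown here) is assumed to expose
# add_edge(state), which records a transition from the previously added
# state, and get_random_token(), which walks the chain probabilistically.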
class Tokenization(Enum):
    """
    word: Interpret the input as UTF-8 and split it at any whitespace,
        using the strings between the whitespace as tokens. So "a b"
        becomes ["a", "b"], as does "a\nb".
    character: Interpret the input as UTF-8 and use the individual
        characters as tokens.
    byte: Read the input as raw bytes and use the individual bytes as
        tokens.
    none: Do not tokenize. The input must be an iterable.
    """
word = 1
character = 2
byte = 3
none = 4
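# Illustrative sketch (not from the original source) of how the input
# "ab c" tokenizes under each mode:
#   word:      ["ab", "c"]
#   character: ["a", "b", " ", "c"]
#   byte:      [0x61, 0x62, 0x20, 0x63]  (the raw bytes of b"ab c")
#   none:      the input iterable is used as-is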
class RandomWriter(object):
"""
A Markov chain based random data generator.
"""
def __init__(self, level, tokenization=Tokenization.none):
"""
Initialize a random writer.
Args:
            level: The context length or "level" of the model to build.
tokenization: A value from Tokenization. This specifies how
the data should be tokenized.
The value given for tokenization will affect what types of
data are supported.
"""
if level < 0:
raise ValueError("The level of analysis must be >= 0.")
if tokenization not in Tokenization:
raise ValueError("You did not provide a valid tokenization mode.")
self._mode = tokenization
self._level = level
self._graph = None
def generate(self):
"""
Yield random tokens using the model, infinitely.
"""
        if self._graph is None:
            raise ValueError("The RandomWriter must be trained before it "
                             "can generate tokens.")
while True:
yield self._graph.get_random_token()
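    # Illustrative consumption (hypothetical rw instance): generate() is
    # infinite, so cap it, e.g. with itertools.islice:
    #   import itertools
    #   first_ten = list(itertools.islice(rw.generate(), 10))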
    def generate_file(self, filename, amount):
        """
        Write a file using the model.

        Args:
            filename: The name of the file (or a file object) to write
                output to.
            amount: The number of tokens to write.

        For character or byte tokens this just outputs the tokens one
        after another. For any other type of token, a space is added
        between tokens.
        """
        if self._mode is Tokenization.byte:
            if not hasattr(filename, 'write'):
                with open(filename, mode="wb") as fi:
                    self.generate_file(fi, amount)
            else:
                gen = self.generate()
                filename.write(bytes(next(gen) for _ in range(amount)))
        else:
            if not hasattr(filename, 'write'):
                with open(filename, mode="w", encoding="utf-8") as fi:
                    self.generate_file(fi, amount)
            else:
                # Create the generator once so successive tokens come from
                # the same walk instead of restarting it every iteration.
                gen = self.generate()
                for _ in range(amount):
                    content = str(next(gen))
                    if self._mode is not Tokenization.character:
                        content += " "
                    filename.write(content)
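    # Illustrative usage (hypothetical names): after training, either a
    # path or an open file object works:
    #   rw.generate_file("out.txt", 100)
    #   rw.generate_file(sys.stdout, 100)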
def save_pickle(self, filename_or_file_object):
"""
Write this model out as a Python pickle.
Args:
filename_or_file_object: A filename or file object to write to.
                File objects are assumed to be opened in binary mode.
"""
if hasattr(filename_or_file_object, 'write'):
pickle.dump(self, filename_or_file_object, pickle.HIGHEST_PROTOCOL)
else:
# Better open the file first
with open(filename_or_file_object, "wb") as fi:
self.save_pickle(fi)
    @classmethod
    def load_pickle(cls, filename_or_file_object):
        """
        Load a Python pickle and make sure it is, in fact, a model.

        Args:
            filename_or_file_object: A filename or file object to load
                from. File objects are assumed to be opened in binary
                mode.

        Return:
            A new instance of RandomWriter which contains the loaded
            data.
        """
        try:
            data = pickle.load(filename_or_file_object)
        except TypeError:
            # We were given a filename rather than a file object, so open
            # the file first; the recursive call also re-runs the
            # isinstance check below on what it loads.
            with open(filename_or_file_object, "rb") as fi:
                return cls.load_pickle(fi)
        if isinstance(data, cls):
            return data
        else:
            # Something bad happened
            raise ValueError("A RandomWriter could not be loaded from "
                             "the file.")
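    # Illustrative round trip (hypothetical filename):
    #   rw.save_pickle("model.pickle")
    #   rw = RandomWriter.load_pickle("model.pickle")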
    def train_url(self, url):
        """
        Compute the probabilities based on the data downloaded from url.

        Args:
            url: The URL to download.
        """
        if self._mode is Tokenization.none:
            raise ValueError("This method is only supported if the "
                             "tokenization mode is not none.")
        with urllib.request.urlopen(url) as response:
            text = response.read()
        if self._mode is not Tokenization.byte:
            try:
                text = str(text, encoding="utf-8")
            except UnicodeDecodeError:
                # Not valid UTF-8; decode what we can, replacing the
                # offending bytes, rather than training on the repr().
                text = str(text, encoding="utf-8", errors="replace")
        self.train_iterable(text)
    def train_iterable(self, data):
        """
        Compute the probabilities based on the data given.

        If the tokenization mode is none, data must be an iterable. If
        the tokenization mode is character or word, then data must be
        a string. Finally, if the tokenization mode is byte, then data
        must be a bytes. If the type is wrong, a TypeError is raised.
        """
        data = self.validate_datatype(data)
        if data is None:
            raise TypeError("Incorrect data given for tokenization mode.")
        self._graph = graph.Graph()
        if self._level == 0:
            for i in range(len(data)):
                state = tuple(data[i:i+1])
                self._graph.add_edge(state)
        else:
            for i in range(len(data) - self._level + 1):
                # Get a sliding window of self._level tokens to store as
                # a state in the graph.
                state = tuple(data[i:i+self._level])
                self._graph.add_edge(state)
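    # Illustrative sketch (not from the original): with level=2 and word
    # tokens, training on "a b c d" adds the states ("a", "b"), ("b", "c")
    # and ("c", "d") to the graph, in that order.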
    def validate_datatype(self, data):
        """
        Ensure the given data type is valid for the Tokenization mode,
        returning the data in a form suitable for iteration, or None if
        the combination of data and mode is invalid.
        """
        if self._mode is Tokenization.word and isinstance(data, str):
            return data.split()
        elif (self._mode is Tokenization.character and isinstance(data, str) or
              self._mode is Tokenization.byte and isinstance(data, bytes)):
            return data
        elif self._mode is Tokenization.none and hasattr(data, '__iter__'):
            # Materialize the iterable so train_iterable can take len()
            # and slices of it.
            return list(data)
        else:
            return None
def train_input(args):
    """
    Constructs a RandomWriter using the given level and tokenization,
    trains it on the input URL or stdin, and pickles the model to the
    output file or stdout.
    """
    if args.character:
        tokenization = Tokenization.character
    elif args.byte:
        tokenization = Tokenization.byte
    else:
        tokenization = Tokenization.word
    rw = RandomWriter(args.level, tokenization)
    if args.input is sys.stdin:
        # For byte mode, read raw bytes from stdin's buffer; the text
        # layer would hand train_iterable a str and trip the type check.
        if tokenization is Tokenization.byte:
            data = args.input.buffer.read()
        else:
            data = args.input.read()
        rw.train_iterable(data)
    else:
        rw.train_url(args.input)
    rw.save_pickle(args.output)
def generate_output(args):
    """
    Reconstructs a RandomWriter from a pickle and writes the requested
    number of generated tokens to the output.
    """
    rw = RandomWriter.load_pickle(args.input)
    rw.generate_file(args.output, args.amount)
if __name__ == '__main__':
    # Handle parsing of the command line arguments.
parser = argparse.ArgumentParser(add_help=True)
subparsers = parser.add_subparsers()
    # The train subcommand
parser_train = subparsers.add_parser('train', help="Train a model given "
"input and save to pickle output.")
parser_train.add_argument('--input', default=sys.stdin)
parser_train.add_argument('--output', default=sys.stdout.buffer)
token_group = parser_train.add_mutually_exclusive_group()
token_group.add_argument('--word', action='store_true')
token_group.add_argument('--character', action='store_true')
token_group.add_argument('--byte', action='store_true')
parser_train.add_argument('--level', type=int, default=1)
parser_train.set_defaults(func=train_input)
    # The generate subcommand
parser_generate = subparsers.add_parser('generate', help="Generate an "
"output file.")
parser_generate.add_argument('--input', default=sys.stdin.buffer)
parser_generate.add_argument('--output', default=sys.stdout)
parser_generate.add_argument('--amount', required=True, type=int)
parser_generate.set_defaults(func=generate_output)
# because we are only using subparsers, argparse will not print help
# by default, so do it manually.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
args = parser.parse_args()
args.func(args)
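# Illustrative invocations (hypothetical file and URL names):
#   python rw.py train --input http://example.com/corpus.txt --word \
#       --level 2 --output model.pickle
#   python rw.py generate --input model.pickle --amount 100 --output out.txt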