-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
73 lines (54 loc) · 1.82 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import time, sys
from torch import set_grad_enabled
import lamp as l
from data import generate_data
from helpers import optimize, compute_accuracy, pickle_load, pickle_dump
from plot import plot_results
# Disable PyTorch autograd globally for the whole run.
# NOTE(review): presumably the `lamp` framework computes its own gradients,
# so torch's graph tracking is unnecessary here — confirm.
set_grad_enabled(False)

# Base name of the model pickle; the messages below report it as
# f'{filename}.pkl', so the helpers presumably append the extension.
filename = 'model'
# Set to True to pickle a freshly trained model to f'{filename}.pkl'.
# False keeps the previous default of training without saving.
save_model = False

# Load data
num_samples = 1000
train_data, test_data = generate_data(num_samples)


def _build_model(input_dim=2, output_dim=1, nb_hidden=25):
    """Return an untrained Linear/Tanh stack with a Sigmoid output layer.

    Args:
        input_dim: Number of input features of the first Linear layer.
        output_dim: Number of outputs of the last Linear layer.
        nb_hidden: Width of the hidden layers.
    """
    return l.Sequential(
        l.Linear(input_dim, nb_hidden),
        l.Tanh(),
        # NOTE(review): single-argument Linear — presumably lamp defaults the
        # output size to the input size (nb_hidden -> nb_hidden); confirm.
        l.Linear(nb_hidden),
        l.Tanh(),
        l.Linear(nb_hidden),
        l.Tanh(),
        l.Linear(nb_hidden, output_dim),
        l.Sigmoid(),
    )


def _countdown(seconds=3):
    """Print an in-place 'Starting in N' countdown, sleeping one second per step."""
    for remaining in range(seconds, 0, -1):
        # '\r' returns to the start of the line so each number overwrites the last.
        sys.stdout.write("\rStarting in %i" % remaining)
        sys.stdout.flush()
        time.sleep(1)
    # Replace the countdown line with the separator banner.
    sys.stdout.write('\r##################################\n')


def _train_model(num_samples, train_data, test_data,
                 epochs=100, batch_size=50, lr=0.16, shuffle=False):
    """Build a fresh model, print the run configuration and train it.

    Args:
        num_samples: Only used to report the dataset size.
        train_data: Training set, passed through to `optimize`.
        test_data: Evaluation set, passed through to `optimize`.
        epochs, batch_size, lr, shuffle: Forwarded to `optimize`.

    Returns:
        The model after `optimize` has run on it.
    """
    model = _build_model()

    print(f'Number of samples: {num_samples}')
    print(f'Epochs: {epochs}')
    print(f'Batch size: {batch_size}')
    print(f'Learning rate: {lr}')
    print(f'Shuffle before each epoch: {shuffle}')
    print()

    _countdown()
    optimize(model, train_data, test_data, epochs=epochs, batch_size=batch_size, lr=lr, shuffle=shuffle)
    print('##################################')
    return model


try:
    # Reuse a previously trained model when its pickle exists.
    model = pickle_load(filename)
except FileNotFoundError:
    # No pickle available: train from scratch.
    model = _train_model(num_samples, train_data, test_data)
    if save_model:
        pickle_dump(filename, model)
        print(f'Trained model pickled to {filename}.pkl')
else:
    print(f'Model loaded from file {filename}.pkl')

# compute_accuracy returns a 3-tuple; the third element is not used here.
# NOTE(review): correct_class presumably marks which test samples were
# classified correctly (it is only forwarded to plot_results) — confirm.
accuracy, correct_class, _ = compute_accuracy(model, *test_data)
print('Test accuracy : {:.1f}%'.format(round(accuracy.item(), 3) * 100))
print('Plotting data and results...')
plot_results(train_data, test_data, correct_class)