forked from moezzie/gobrain
/
feedforward_test.go
92 lines (73 loc) · 2.36 KB
/
feedforward_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package gobrain
import (
"testing"
"math/rand"
"reflect"
)
// ExampleSimpleFeedForward trains a network with 2 inputs, 2 hidden nodes and
// 1 output on the XOR truth table, then runs every training pattern through
// the trained network.
func ExampleSimpleFeedForward() {
	// Training parameters handed to Train below.
	const (
		epochs       = 1000  // number of training epochs
		learningRate = 0.6   // learning rate
		momentum     = 0.4   // momentum factor
		verbose      = false // set to true to receive reports about the learning error
	)

	// Seed math/rand with 0; the expected output below presumably depends on
	// this seed.
	rand.Seed(0)

	// XOR truth table; every entry is {inputs, expected outputs}.
	xorTable := [][][]float64{
		{{0, 0}, {0}},
		{{0, 1}, {1}},
		{{1, 0}, {1}},
		{{1, 1}, {0}},
	}

	// Build the network: 2 inputs, 2 hidden nodes, 1 output.
	net := new(FeedForward)
	net.Init(2, 2, 1)

	// Train on the XOR patterns.
	net.Train(xorTable, epochs, learningRate, momentum, verbose)

	// Test the network against the same patterns.
	net.Test(xorTable)

	// Predict a single value; the result is not used by this example.
	net.Update([]float64{1, 1})

	// Output:
	// [0 0] -> [0.05750394570844524] : [0]
	// [0 1] -> [0.9301006350712102] : [1]
	// [1 0] -> [0.927809966227284] : [1]
	// [1 1] -> [0.09740879532462095] : [0]
}
// TestSerialize verifies that a network restored with Load from the output of
// Serialize is deeply equal to the network it was serialized from.
func TestSerialize(t *testing.T) {
	// Seed math/rand with 0 before the network is initialized.
	rand.Seed(0)

	// XOR truth table used as training data; every entry is
	// {inputs, expected outputs}.
	patterns := [][][]float64{
		{{0, 0}, {0}},
		{{0, 1}, {1}},
		{{1, 0}, {1}},
		{{1, 1}, {0}},
	}

	// Build the network: 2 inputs, 2 hidden nodes and 1 output.
	ff := &FeedForward{}
	ff.Init(2, 2, 1)

	// Train for only 10 epochs with learning rate 0.6 and momentum factor
	// 0.4; the last argument (false) disables learning-error reports. This
	// test checks the serialization round trip, not how well the network
	// has learned.
	ff.Train(patterns, 10, 0.6, 0.4, false)

	// Serialize the newly trained network to JSON.
	serialized, err := ff.Serialize()
	if err != nil {
		// Stop here: loading and comparing the result of a failed
		// serialization would only add a misleading second failure.
		t.Fatalf("Serialization failed: %v", err)
	}

	// Create a new network and load the serialized data into it.
	// NOTE(review): if Load returns an error it should be checked here;
	// confirm against its signature.
	nn := &FeedForward{}
	nn.Load(serialized)

	// The loaded network must match the original in every value.
	if !reflect.DeepEqual(ff, nn) {
		t.Error("Network was not loaded correctly")
	}
}