-
Notifications
You must be signed in to change notification settings - Fork 0
/
unit.go
143 lines (125 loc) · 2.29 KB
/
unit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package units
import (
"fmt"
"math"
"reflect"
"sort"
"strings"
)
// dims maps a base-unit name to its integer exponent, e.g. {"m": 1, "s": -2}.
type dims map[string]int

// unit is the package's concrete Unit: a multiplicative scale factor applied
// to a product of named base units raised to integer exponents.
type unit struct {
	scale float64 // multiplicative factor applied to the dimensional part
	dims  dims    // exponent for each base-unit name
}

// Unit is a unit of measure expressed as a scale factor times a product of
// named base units raised to integer powers.
type Unit interface {
	// String renders the unit as text.
	String() string
	// IsScalar reports whether the unit carries no dimensions.
	IsScalar() bool
	// Equal reports whether the unit is the same as the argument.
	Equal(Unit) bool
	// Subs replaces named dimensions with the units the map assigns to them.
	Subs(map[string]Unit) Unit
	// Validate returns an error if a dimension is not present in the map.
	Validate(map[string]Unit) error
	// Invert returns the reciprocal unit.
	Invert() Unit
	// Multiply returns the product of the receiver and the argument.
	Multiply(Unit) Unit
	// Divide returns the receiver divided by the argument.
	Divide(Unit) Unit
	// Scale returns the multiplicative scale factor.
	Scale() float64
}
// Scalar returns a dimensionless unit with the given scale factor.
// makeUnit substitutes 1 for a zero scale, so Scalar(0) equals Scalar(1).
func Scalar(scale float64) Unit {
	// makeUnit allocates a fresh, empty dims map when given nil.
	return makeUnit(scale, nil)
}
// NewUnit returns the base unit name raised to the power dim, with scale 1.
// An empty name or a zero exponent yields the dimensionless Scalar(1).
func NewUnit(name string, dim int) Unit {
	if name != "" && dim != 0 {
		return makeUnit(1.0, dims{name: dim})
	}
	return Scalar(1)
}
// makeUnit builds a unit value, applying defaults:
//   - a nil dims map is replaced by a new empty map, so callers may write to it;
//   - a zero scale (including -0) is replaced by 1.
//
// The dms map is stored as-is (not copied).
func makeUnit(scale float64, dms dims) unit {
	u := unit{scale: scale, dims: dms}
	if u.dims == nil {
		u.dims = dims{}
	}
	if u.scale == 0.0 {
		u.scale = 1.0
	}
	return u
}
// IsScalar reports whether u has no dimension entries at all.
// Note: an entry whose exponent is 0 still counts as a dimension here.
func (u unit) IsScalar() bool {
	n := len(u.dims)
	return n == 0
}
// Subs rewrites u by replacing every dimension that has an entry in us with
// that entry's unit raised to the dimension's exponent. Dimensions absent
// from us are kept unchanged. The scale is multiplied by each replacement's
// scale raised to the same exponent.
//
// Substitution is repeated until a pass leaves the unit unchanged, so
// definitions may refer to other defined units. Dimensions whose exponents
// cancel to zero are removed, matching Multiply.
//
// NOTE(review): a cyclic definition (a→b, b→a) or a self-referential one with
// a scale other than 1 (m→2*m) never reaches a fixed point and recurses
// without bound. Every value in us must be this package's unit type; any
// other value (including nil) makes the type assertion below panic.
func (u unit) Subs(us map[string]Unit) Unit {
	out := makeUnit(u.scale, nil)
	for name, exp := range u.dims {
		repl, ok := us[name]
		if !ok {
			out.dims[name] += exp
			continue
		}
		r := repl.(unit)
		out.scale *= math.Pow(r.scale, float64(exp))
		for rname, rexp := range r.dims {
			out.dims[rname] += exp * rexp
		}
	}
	// Drop dimensions that cancelled out (e.g. kg*kg^-1); otherwise they
	// would print as "kg^0" and make IsScalar/Equal misjudge the result.
	// Deleting while ranging over a map is safe in Go.
	for name, exp := range out.dims {
		if exp == 0 {
			delete(out.dims, name)
		}
	}
	if !u.Equal(out) {
		// Some substitutions were made, so reprocess the result.
		return out.Subs(us)
	}
	return out
}
// Validate returns an error naming a dimension of u that has no entry in us,
// or nil if every dimension is defined. When several dimensions are
// undefined, the lexically smallest name is reported. This keeps the error
// deterministic despite Go's randomized map iteration order.
func (u unit) Validate(us map[string]Unit) error {
	missing, found := "", false
	for name := range u.dims {
		if _, ok := us[name]; ok {
			continue
		}
		if !found || name < missing {
			missing, found = name, true
		}
	}
	if !found {
		return nil
	}
	return fmt.Errorf("the unit '%s' is not defined", missing)
}
// Equal reports whether u2 is a unit with the same scale and the same
// exponent for every dimension. Dimensions with a zero exponent are treated
// as absent, so a nil dims map equals an empty one. A nil u2, or a Unit not
// backed by this package's unit type, is never equal to u.
func (u unit) Equal(u2 Unit) bool {
	// Comma-ok assertion: the previous u2.Scale() call panicked on a nil u2.
	other, ok := u2.(unit)
	if !ok {
		return false
	}
	if u.scale != other.scale {
		return false
	}
	return reflect.DeepEqual(nonZeroDims(u.dims), nonZeroDims(other.dims))
}

// nonZeroDims returns a copy of d without zero-exponent entries; a nil d
// yields an empty, non-nil map so nil and empty compare equal.
func nonZeroDims(d dims) dims {
	out := make(dims, len(d))
	for name, exp := range d {
		if exp != 0 {
			out[name] = exp
		}
	}
	return out
}
// String renders u as a "*"-separated product of its dimensions, sorted as
// strings. An exponent of 1 is written as the bare name and any other
// exponent as name^exp; zero exponents are omitted. A scale other than 1 is
// prefixed in %.17e form. A dimensionless unit prints only its scale, or ""
// when the scale is 1.
func (u unit) String() string {
	factors := make([]string, 0, len(u.dims))
	for name, exp := range u.dims {
		switch exp {
		case 0:
			// A zero exponent contributes nothing; don't print "x^0".
		case 1:
			factors = append(factors, name)
		default:
			factors = append(factors, fmt.Sprintf("%s^%d", name, exp))
		}
	}
	sort.Strings(factors)
	body := strings.Join(factors, "*")
	if u.scale == 1.0 {
		return body
	}
	scale := fmt.Sprintf("%.17e", u.scale)
	if body == "" {
		// Avoid a dangling separator such as "2.00000000000000000e+00*".
		return scale
	}
	return scale + "*" + body
}
// Invert returns the reciprocal of u: scale 1/u.scale and every exponent
// negated. The receiver's dims map is not modified.
func (u unit) Invert() Unit {
	inv := makeUnit(1.0/u.scale, nil)
	for name := range u.dims {
		inv.dims[name] = -u.dims[name]
	}
	return inv
}
// Multiply returns u * u2. The scales are multiplied, and exponents of
// shared dimensions are added. A dimension removed by the addition (its sum
// reaches zero) is deleted from the result. Neither operand's map is
// modified. u2 must be this package's unit type; the type assertion panics
// otherwise.
func (u unit) Multiply(u2 Unit) Unit {
	product := makeUnit(u.scale*u2.Scale(), nil)
	for name, exp := range u.dims {
		product.dims[name] = exp
	}
	other := u2.(unit)
	for name, exp := range other.dims {
		if sum := product.dims[name] + exp; sum != 0 {
			product.dims[name] = sum
		} else {
			delete(product.dims, name)
		}
	}
	return product
}
// Divide returns u / ut, computed as u multiplied by the inverse of ut.
func (u unit) Divide(ut Unit) Unit {
	inverse := ut.Invert()
	return u.Multiply(inverse)
}
// Scale returns the multiplicative factor applied to u's dimensional part.
func (u unit) Scale() float64 {
	factor := u.scale
	return factor
}