forked from probmods/webppl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
interpolating-cache.wppl
103 lines (78 loc) · 2.52 KB
/
interpolating-cache.wppl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// Plans:
// 1. Andreas and Jason: learn about variational inference in webppl
// 2. Jason: Experiment with code below and turn it into something that actually runs
// Number of flips drawn per trial inside slowBinomial.
var N = 11;
// Cut-off that `surrogate` compares its predictive variance against: above it
// the real function is evaluated, otherwise the predicted mean is returned.
var threshold = 0.03;
// Expected value of the distribution `erp`, obtained by handing `expectation`
// the identity function.
var mean = function(erp) {
  var identity = function(value) {
    return value;
  };
  return expectation(erp, identity);
};
// Distribution, built with Enumerate, over the sum of N draws of flip(p).
var slowBinomial = function(p) {
  // One trial: draw N flips with bias p and add them up.
  var countSuccesses = function() {
    var flips = repeat(N, function() {
      return flip(p);
    });
    return sum(flips);
  };
  return Enumerate(countSuccesses);
};
// The expensive function this file tries to speed up by caching: the mean of
// the slowBinomial(p) distribution.
var slowBinomialMean = function(p) {
  var successCountDist = slowBinomial(p);
  return mean(successCountDist);
};
// versions:
// 1. plain cache (just store in db; look up)
// 2. naive similarity cache (look up closest value in db; if distance below threshold return, otherwise re-eval)
// 3. principled inference version
// Initial distribution over candidate surrogate functions, stored in
// globalStore so updateDistOnFuncs can replace it later. Each draw is a noisy
// line: x -> gaussian(slope * x + intercept, 0.1), with intercept and slope
// each drawn from uniform(-5, 5).
// NOTE(review): 0.1 is presumably the gaussian's standard deviation - confirm
// against the WebPPL version in use.
globalStore.distOnFuncs = MCMC(function() {
  var intercept = uniform(-5, 5);
  var slope = uniform(-5, 5);
  var noisyLine = function(x) {
    return gaussian(slope * x + intercept, 0.1);
  };
  return noisyLine;
});
// Replaces globalStore.distOnFuncs with the result of running Variational on a
// model that samples a function from the current globalStore.distOnFuncs,
// conditions on that function mapping `input` to `output`, and returns it.
// NOTE(review): the functions in the initial distribution return gaussian
// draws, so this exact-equality condition may essentially never hold; a soft
// constraint (e.g. factor) is probably what is intended - confirm.
var updateDistOnFuncs = function(input, output) {
  var conditionedModel = function() {
    var candidate = sample(globalStore.distOnFuncs);
    condition(candidate(input) == output);
    return candidate;
  };
  globalStore.distOnFuncs = Variational(conditionedModel);
};
// Variance of the distribution `erp`: the expected squared deviation from its
// mean.
var variance = function(erp) {
  var center = mean(erp);
  var squaredDeviation = function(x) {
    return Math.pow(x - center, 2);
  };
  return expectation(erp, squaredDeviation);
};
// Answers func(arg) from the learned distribution over functions when that
// distribution is confident, and falls back to the real function otherwise.
//
// The predictive distribution at `arg` comes from running MCMC over functions
// sampled from globalStore.distOnFuncs. If its variance exceeds `threshold`,
// func(arg) is evaluated, recorded through updateDistOnFuncs, and returned;
// otherwise the predictive mean is returned and func is not called.
// Note that the prediction itself does not depend on `func`: every function
// routed through here shares globalStore.distOnFuncs.
var surrogate = function(func, arg) {
  var predictAtArg = function() {
    var candidate = sample(globalStore.distOnFuncs);
    return candidate(arg);
  };
  var predictive = MCMC(predictAtArg);
  var predictedMean = mean(predictive);
  var predictedVariance = variance(predictive);

  // Evaluate the real function and feed the observation back to the model.
  var evaluateAndLearn = function() {
    var exactValue = func(arg);
    updateDistOnFuncs(arg, exactValue);
    return exactValue;
  };

  return predictedVariance > threshold ? evaluateAndLearn() : predictedMean;
};
// makeCachedFunction : (func, surrogateFunc) -> cachedFunc
// Wraps `func` so that every call cachedFunc(arg) is delegated to
// surrogateFunc(func, arg); surrogateFunc decides whether func actually runs.
var makeCachedFunction = function(func, surrogateFunc) {
  var cachedFunc = function(arg) {
    return surrogateFunc(func, arg);
  };
  return cachedFunc;
};
// Surrogate-backed counterpart of slowBinomialMean: each call is routed
// through `surrogate`, which receives slowBinomialMean as the real function.
var cachedBinomialMean = makeCachedFunction(slowBinomialMean, surrogate);
// Model body for inferring p: draws p from uniform(0, 1), conditions on
// meanFunc(p) being below 3 (we know that the mean is < 3), and returns p.
// `meanFunc` maps a value of p to a mean.
var inferBinomialP = function(meanFunc) {
  var p = uniform(0, 1);
  condition(meanFunc(p) < 3);
  return p;
};
// Run the same inference over p twice with Rejection (second argument 30):
// first against the exact mean, then against the surrogate-backed mean.
// The statement order is significant: calls to cachedBinomialMean go through
// `surrogate`, which may replace globalStore.distOnFuncs as a side effect.
var slowBinomialPDist = Rejection(function() { return inferBinomialP(slowBinomialMean); }, 30);
var cachedBinomialPDist = Rejection(function() { return inferBinomialP(cachedBinomialMean); }, 30);
// Report the mean of p under each result so the two can be compared.
console.log("Slow mean p: " + mean(slowBinomialPDist));
console.log("Cached mean p: " + mean(cachedBinomialPDist));
// Print both resulting distributions in full.
console.log("Slow p:");
slowBinomialPDist.print();
console.log("Cached p:");
cachedBinomialPDist.print();