1 /**
2   Useful math functions and range-based statistic computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9 	   std.traits,
10        std.range,
11        std.math;
12 
13 static if( __VERSION__ < 2066 ) private enum nogc = 1;
14 
/// Returns: the smaller of a and b. When neither compares less than the
/// other, b is returned.
@nogc T min(T)(T a, T b) pure nothrow
{
    if (a < b)
        return a;
    return b;
}
20 
/// Returns: the larger of a and b. When neither compares greater than the
/// other, b is returned.
@nogc T max(T)(T a, T b) pure nothrow
{
    if (a > b)
        return a;
    return b;
}
26 
/// Converts an angle from radians to degrees.
/// Not available for integral types.
@nogc T degrees(T)(T x) pure nothrow if (!isIntegral!T)
{
    // Number of degrees in one radian.
    enum degreesPerRadian = 180 / PI;
    return x * degreesPerRadian;
}
32 
/// Converts an angle from degrees to radians.
/// Not available for integral types.
@nogc T radians(T)(T x) pure nothrow if (!isIntegral!T)
{
    // Number of radians in one degree.
    enum radiansPerDegree = PI / 180;
    return x * radiansPerDegree;
}
38 
/// Linear interpolation, akin to GLSL's mix: yields a for t = 0 and b for t = 1.
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
{
    auto shareOfB = t * b;
    auto shareOfA = (1 - t) * a;
    return shareOfB + shareOfA;
}
44 
/// Clamps x to [min, max], akin to GLSL's clamp.
/// The lower bound is tested first; x is returned unchanged when it is
/// neither below min nor above max.
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    return (x < min) ? min
         : (x > max) ? max
         : x;
}
55 
/// Integer truncation: rounds x toward zero and converts the result to long.
@nogc long ltrunc(real x) nothrow // may be pure but trunc isn't pure
{
    immutable real towardZero = trunc(x);
    return cast(long) towardZero;
}
61 
/// Integer flooring: rounds x toward negative infinity and converts the result to long.
@nogc long lfloor(real x) nothrow // may be pure but floor isn't pure
{
    immutable real roundedDown = floor(x);
    return cast(long) roundedDown;
}
67 
/// Returns: Fractional part of x, computed as x minus its floor.
/// T only appears in the return type, so it must be given explicitly
/// (e.g. fract!double(x)).
@nogc T fract(T)(real x) nothrow
{
    // Same floor-to-long conversion that lfloor performs.
    immutable long wholePart = cast(long)(floor(x));
    return x - wholePart;
}
73 
/// Safe asin: the input is clamped to [-1, 1] before asin is applied.
@nogc T safeAsin(T)(T x) pure nothrow
{
    immutable T bounded = clamp!T(x, -1, 1);
    return asin(bounded);
}
79 
/// Safe acos: the input is clamped to [-1, 1] before acos is applied.
@nogc T safeAcos(T)(T x) pure nothrow
{
    immutable T bounded = clamp!T(x, -1, 1);
    return acos(bounded);
}
85 
/// Same as GLSL step function.
/// Returns: 0 if x < edge, 1 otherwise.
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
92 
/// Same as GLSL smoothstep function.
/// Returns: 0 when t <= a, 1 when t >= b, and the Hermite blend
/// x * x * (3 - 2 * x) with x = (t - a) / (b - a) in between.
/// See: http://en.wikipedia.org/wiki/Smoothstep
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    // Position of t inside [a, b], rescaled to [0, 1].
    immutable T u = (t - a) / (b - a);
    return u * u * (3 - 2 * u);
}
107 
/// Returns: true if i is a power of 2. Zero is not a power of 2.
/// Negative values are rejected by an assertion.
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;
    // A power of 2 has a single bit set, which i - 1 clears.
    return (i & (i - 1)) == 0;
}
114 
/// Integer log2.
/// The argument is asserted to be strictly positive and a power of 2.
/// TODO: use bt intrinsics
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    assert(i > 0);
    // Power-of-2 check: a single set bit is cleared by i - 1.
    assert((i & (i - 1)) == 0);

    // Count how many halvings bring i down to 1.
    int exponent = 0;
    for (; i > 1; i /= 2)
        ++exponent;
    return exponent;
}
129 
/// Computes next power of 2.
/// Returns: the smallest power of 2 that is >= i, for i in [1, 2^30].
/// Any other input makes the final assertion fail.
@nogc int nextPowerOf2(int i) pure nothrow
{
    int v = i - 1;
    // Copy the highest set bit of v into every lower bit position.
    for (int shift = 1; shift <= 16; shift *= 2)
        v |= v >> shift;
    ++v;
    // The result must be a strictly positive power of 2.
    assert(v > 0 && (v & (v - 1)) == 0);
    return v;
}
143 
/// Computes next power of 2.
/// Returns: the smallest power of 2 that is >= i, for i in [1, 2^62].
/// Any other input makes the final assertion fail.
@nogc long nextPowerOf2(long i) pure nothrow
{
    long v = i - 1;
    // Copy the highest set bit of v into every lower bit position.
    for (int shift = 1; shift <= 32; shift *= 2)
        v |= v >> shift;
    ++v;
    // The result must be a strictly positive power of 2.
    assert(v > 0 && (v & (v - 1)) == 0);
    return v;
}
158 
/// Computes sin(x)/x accurately.
/// Returns: 1 when x * x is too small to change 1 when added to it,
/// sin(x) / x otherwise.
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
@nogc T sinOverX(T)(T x) pure nothrow
{
    if (1 + x * x != 1)
        return sin(x) / x;
    return 1;
}
168 
169 
/// Signed integer modulo a/b where the remainder is guaranteed to be in [0..b[,
/// even if a is negative. Only supports positive dividers.
///
/// Params:
///     a = dividend, may be negative.
///     b = divider, must be strictly positive (checked by an assertion).
/// Returns: the remainder of a by b, wrapped into [0..b[.
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
{
    // Precondition kept as a plain assertion so the function does not depend
    // on the contract-body keyword, which differs between compiler versions.
    assert(b > 0);

    // % truncates toward zero: the remainder has the sign of a and lies in ]-b..b[.
    // The casts are needed for types narrower than int, whose arithmetic promotes to int.
    T x = cast(T)(a % b);
    if (x < 0)
        x = cast(T)(x + b); // shift a negative remainder up by one period

    assert(x >= 0 && x < b);
    return x;
}
190 
unittest
{
    // 13 is not a power of 2, so it rounds up to 16.
    immutable int rounded = nextPowerOf2(13);
    assert(rounded == 16);
}
195 
/**
 * Finds the root of a linear polynomial a + b x = 0
 * Params:
 *     a = constant coefficient.
 *     b = coefficient of x.
 *     root = receives -a / b when b is not zero. Being an `out` parameter,
 *            it is reset to T.init on entry.
 * Returns: Number of roots: 0 when b is zero, 1 otherwise.
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    return 0;
}
212 
213 
/**
 * Finds the roots of a quadratic polynomial a + b x + c x^2 = 0
 * Params:
 *     a = constant coefficient.
 *     b = coefficient of x.
 *     c = coefficient of x^2; when zero the equation is solved as a linear one.
 *     outRoots = array of root results, should have room for at least 2 elements.
 * Returns: Number of roots in outRoots. A double root (zero discriminant)
 *          is reported as two equal roots.
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);
    if (c == 0)
        return solveLinear(a, b, outRoots[0]);

    // Discriminant of c x^2 + b x + a.
    T delta = b * b - 4 * a * c;
    if (delta < 0.0)
        return 0;

    delta = sqrt(delta);

    // c is the coefficient of x^2, so the quadratic formula divides by 2c
    // (dividing by 2a would solve the polynomial with its coefficients reversed).
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
237 
238 
/**
 * Finds the roots of a cubic polynomial  a + b x + c x^2 + d x^3 = 0
 * Params:
 *     a = constant coefficient.
 *     b = coefficient of x.
 *     c = coefficient of x^2.
 *     d = coefficient of x^3; when zero the equation is solved as a quadratic one.
 *     outRoots = array of root results, should have room for at least 3 elements.
 * Returns: Number of roots in outRoots: 3 (repeated roots are listed once
 *          per multiplicity) or 1 when d is not zero.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // Normalize to x^3 + a1 x^2 + a2 x + a3 = 0
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R;

    // Every root is offset by -a1 / 3.
    T shift = a1 / 3;

    if (d2 >= 0)
    {
        // 3 real roots. Here Q^3 >= R^2 >= 0.
        if (Qcubed <= 0)
        {
            // Q^3 == 0 forces R == 0: a single root of multiplicity 3.
            // Handled apart because R / sqrt(Q^3) would be 0 / 0.
            outRoots[0] = -shift;
            outRoots[1] = -shift;
            outRoots[2] = -shift;
            return 3;
        }

        T P = R / sqrt(Qcubed);

        // |P| <= 1 mathematically; rounding can push it marginally outside
        // the domain of acos, so bring it back.
        if (P < -1)
            P = -1;
        else if (P > 1)
            P = 1;

        T theta = acos(P);
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - shift;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * PI) / 3) - shift;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * PI) / 3) - shift;
        return 3;
    }
    else
    {
        // 1 real root. The radicand is R^2 - Q^3 = -d2, which is strictly
        // positive in this branch, so e cannot be zero.
        T e = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            e = -e;
        outRoots[0] = e + Q / e - shift;
        return 1;
    }
}
289 
/**
 * Finds the real roots of a quartic polynomial  a + b x + c x^2 + d x^3 + e x^4 = 0
 *
 * Params:
 *     a = constant coefficient.
 *     b = coefficient of x.
 *     c = coefficient of x^2.
 *     d = coefficient of x^3.
 *     e = coefficient of x^4; when zero the equation is solved as a cubic one.
 *     roots = receives the roots, should have room for at least 4 elements.
 * Returns: Number of real roots written at the start of roots (0, 2 or 4
 *          when e is not zero).
 * Bugs: flagged by the original author as not passing its unit-test. All
 *       decisions compare against zero exactly, without tolerance, so roots
 *       of near-degenerate inputs may be lost to rounding.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Normalize to x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Find a root for the following cubic equation:
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1 ^2 - a3^2 a0) = 0
    // aka Resolvent cubic
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);

    // The resolvent may have one real root or three; only the first
    // numRoots entries hold results.
    if (numRoots < 1)
        return 0;

    // Use the largest real root of the resolvent.
    T y = resolventCubicRoots[0];
    foreach (candidate; resolventCubicRoots[1 .. numRoots])
    {
        if (y < candidate)
            y = candidate;
    }

    // Compute R and the squares of D & E
    T R = 0.25f * a3 * a3 - a2 + y;
    if (R < 0.0)
        return 0;
    R = sqrt(R);

    T Dsquared,
      Esquared;
    if (R == 0)
    {
        T d1 = 0.75f * a3 * a3 - 2 * a2;
        T d2 = 2 * sqrt(y * y - 4 * a0);
        Dsquared = d1 + d2;
        Esquared = d1 - d2;
    }
    else
    {
        T Rsquare = R * R;
        T Rrec = 1 / R;
        T d1 = 0.75f * a3 * a3 - Rsquare - 2 * a2;
        T d2 = 0.25f * Rrec * (4 * a3 * a2 - 8 * a1 - a3 * a3 * a3);
        Dsquared = d1 + d2;
        Esquared = d1 - d2;
    }

    // Each non-negative radicand yields a pair of real roots. A negative
    // or NaN radicand corresponds to a non-real pair, which is skipped
    // instead of being reported as NaN.
    a3 *= -0.25f;
    R *= 0.5f;

    int count = 0;
    if (Dsquared >= 0)
    {
        T D = sqrt(Dsquared) * 0.5f;
        roots[count++] = a3 + R + D;
        roots[count++] = a3 + R - D;
    }
    if (Esquared >= 0)
    {
        T E = sqrt(Esquared) * 0.5f;
        roots[count++] = a3 - R + E;
        roots[count++] = a3 - R - E;
    }
    return count;
}
358 
359 
unittest
{
    // Tells whether some element of candidates lies within 1e-7 of expected.
    bool arrayContainsRoot(double[] candidates, double expected)
    {
        foreach (value; candidates)
        {
            immutable double distance = abs(value - expected);
            if (distance < 1e-7)
                return true;
        }
        return false;
    }

    // test cubic: -2 - 3/2 x + 3/4 x^2 + 1/4 x^3 has roots -4, -1 and 2
    {
        double[3] found;
        immutable int count = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, found[]);
        assert(count == 3);
        assert(arrayContainsRoot(found[], -4));
        assert(arrayContainsRoot(found[], -1));
        assert(arrayContainsRoot(found[], 2));
    }

    // test quartic: -2 x - x^2 + 2 x^3 + x^4 has roots -2, -1, 0 and 1
    {
        double[4] found;
        immutable int count = solveQuartic!double(0, -2, -1, 2, 1, found[]);

        assert(count == 4);
        assert(arrayContainsRoot(found[], -2));
        assert(arrayContainsRoot(found[], -1));
        assert(arrayContainsRoot(found[], 0));
        assert(arrayContainsRoot(found[], 1));
    }
}
392 
/// Arithmetic mean.
/// Params:
///     r = input range of numeric elements; it is consumed.
/// Returns: the mean of the elements, or NaN for an empty range.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    // Accumulate in a floating-point type at least as wide as double.
    // Summing in the element type would make the final division an integer
    // division for integral ranges (e.g. the mean of [1, 2] would be 1).
    CommonType!(ElementType!R, double) sum = 0;
    long count = 0;
    foreach(e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
408 
/// Minimum of a range.
/// Returns: the smallest element converted to double, or +infinity when
/// the range is empty.
double minElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!("<", R)(r);

    // do like Javascript for an empty range
    return double.infinity;
}
418 
/// Maximum of a range.
/// Returns: the largest element converted to double, or -infinity when
/// the range is empty.
double maxElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!(">", R)(r);

    // do like Javascript for an empty range
    return -double.infinity;
}
428 
/// Variance of a range.
/// Returns: NaN for an empty range, 0 for a single element, otherwise the
/// sum of squared deviations from the mean divided by (count - 1).
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    // First pass runs on a saved copy so that r can be traversed again below.
    const double mean = average(r.save);

    double squaredDeviations = 0;
    long n = 0;
    foreach (element; r)
    {
        squaredDeviations += (element - mean) ^^ 2;
        ++n;
    }

    // using the sample variance (n - 1 denominator) as estimator
    return (n <= 1) ? 0.0 : (squaredDeviations / (n - 1.0));
}
449 
/// Standard deviation of a range: the square root of its variance.
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    immutable double spread = variance(r);
    return sqrt(spread);
}
455 
private
{
    /// Scans a non-empty range and keeps the element preferred by op.
    /// Params:
    ///     op = comparison operator as a string, e.g. "<" for the minimum
    ///          or ">" for the maximum.
    ///     r = input range; must not be empty (checked by an assertion).
    /// Returns: the first element e for which no other element x
    ///          satisfies `x op e`.
    ElementType!R minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);

        // The element type comes from ElementType rather than from calling
        // front on the type R itself, which does not resolve for built-in
        // arrays. The accumulator is unqualified so that it can be
        // reassigned when the elements are const or immutable.
        Unqual!(ElementType!R) best = r.front;
        r.popFront();
        foreach(e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}