1 /**
2   Useful math functions and range-based statistic computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9        std.traits,
10        std.range,
11        std.math;
12 
// Define a single `AsmX86` version identifier whenever the compiler provides
// inline assembly for either 32-bit or 64-bit x86, so that code further down
// can test one identifier instead of two.
version( D_InlineAsm_X86 )
{
    version = AsmX86;
}
else version( D_InlineAsm_X86_64 )
{
    version = AsmX86;
}
21 
/// Converts an angle expressed in radians to degrees.
/// Params:
///     x = Angle in radians (any non-integral type).
/// Returns: The same angle in degrees.
@nogc T degrees(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum radiansToDegrees = 180 / PI; // compile-time constant factor
    return x * radiansToDegrees;
}
27 
/// Converts an angle expressed in degrees to radians.
/// Params:
///     x = Angle in degrees (any non-integral type).
/// Returns: The same angle in radians.
@nogc T radians(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum degreesToRadians = PI / 180; // compile-time constant factor
    return x * degreesToRadians;
}
33 
/// Linear interpolation between two values, akin to GLSL's mix.
/// Params:
///     a = Value returned for t = 0.
///     b = Value returned for t = 1.
///     t = Interpolation weight applied to `b`; `a` is weighted by (1 - t).
/// Returns: t * b + (1 - t) * a
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
  if (is(typeof(t * b + (1 - t) * a) : S))
{
    auto weightOfA = 1 - t;
    return t * b + weightOfA * a;
}
40 
/// Restricts x to the interval [min, max], akin to GLSL's clamp.
/// Params:
///     x = Value to restrict.
///     min = Lower bound, returned when x is below it.
///     max = Upper bound, returned when x is above it.
/// Returns: min, max, or x itself when it compares neither below min nor above max.
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    return (x < min) ? min
         : (x > max) ? max
         : x;
}
51 
/// Truncates x toward zero and converts the result to a long.
/// Params:
///     x = Value to truncate.
/// Returns: trunc(x) cast to long.
@nogc long ltrunc(real x) nothrow // left non-pure: original note says trunc isn't pure
{
    real towardZero = trunc(x);
    return cast(long) towardZero;
}
57 
/// Rounds x down (toward negative infinity) and converts the result to a long.
/// Params:
///     x = Value to floor.
/// Returns: floor(x) cast to long.
@nogc long lfloor(real x) nothrow // left non-pure: original note says floor isn't pure
{
    real roundedDown = floor(x);
    return cast(long) roundedDown;
}
63 
/// Fractional part of x, computed as x - floor(x).
///
/// The result type T cannot be deduced from the argument and must be given
/// explicitly, e.g. `fract!float(x)`.
///
/// The subtraction is done on the floating-point floor directly rather than
/// through a `long`, so magnitudes beyond the range of `long` no longer produce
/// a garbage result.
/// Params:
///     x = Input value.
/// Returns: x - floor(x), in [0, 1) for finite x, converted to T.
@nogc T fract(T)(real x) nothrow
{
    return x - floor(x);
}
69 
/// Arc sine that first clamps its input to [-1, 1].
/// Params:
///     x = Input value.
/// Returns: asin of x after clamping.
@nogc T safeAsin(T)(T x) pure nothrow
{
    T bounded = clamp!T(x, -1, 1);
    return asin(bounded);
}
75 
/// Arc cosine that first clamps its input to [-1, 1].
/// Params:
///     x = Input value.
/// Returns: acos of x after clamping.
@nogc T safeAcos(T)(T x) pure nothrow
{
    T bounded = clamp!T(x, -1, 1);
    return acos(bounded);
}
81 
/// Threshold function, same as GLSL's step.
/// Params:
///     edge = Threshold.
///     x = Value compared against the threshold.
/// Returns: 0 when x < edge, 1 otherwise.
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
88 
/// Smooth Hermite step between two edges, same as GLSL's smoothstep.
/// See: http://en.wikipedia.org/wiki/Smoothstep
/// Params:
///     a = Lower edge.
///     b = Upper edge.
///     t = Value to evaluate.
/// Returns: 0 when t <= a, 1 when t >= b, otherwise s * s * (3 - 2 * s)
///          where s = (t - a) / (b - a).
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    // Outside the edges the result saturates.
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    // Position of t between the two edges.
    T s = (t - a) / (b - a);
    return s * s * (3 - 2 * s);
}
103 
/// Tests whether an integer is a power of 2.
/// Params:
///     i = Integer to test; asserted to be non-negative.
/// Returns: true if i is a power of 2, false otherwise (including for 0).
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;

    // Clearing the lowest set bit leaves nothing only when a single bit was set.
    return (i & (i - 1)) == 0;
}
110 
/// Integer base-2 logarithm.
/// TODO: use bt intrinsics
/// Params:
///     i = Input; asserted to be strictly positive and a power of 2.
/// Returns: The number of times i can be halved before it is no longer greater than 1.
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    assert(i > 0);
    assert(isPowerOf2(i));

    int exponent = 0;
    for (; i > 1; i /= 2)
        ++exponent;
    return exponent;
}
125 
/// Computes the next power of 2 for a 32-bit integer (e.g. 13 gives 16).
/// Params:
///     i = Input value.
/// Returns: The computed value, which is asserted to be a power of 2.
@nogc int nextPowerOf2(int i) pure nothrow
{
    int v = i - 1;

    // Propagate the bits of v downward with shifts of 1, 2, 4, 8 and 16.
    for (int shift = 1; shift <= 16; shift *= 2)
        v |= v >> shift;

    ++v;
    assert(isPowerOf2(v));
    return v;
}
139 
/// Computes the next power of 2 for a 64-bit integer.
/// Params:
///     i = Input value.
/// Returns: The computed value, which is asserted to be a power of 2.
@nogc long nextPowerOf2(long i) pure nothrow
{
    long v = i - 1;

    // Propagate the bits of v downward with shifts of 1, 2, 4, 8, 16 and 32.
    for (int shift = 1; shift <= 32; shift *= 2)
        v |= v >> shift;

    ++v;
    assert(isPowerOf2(v));
    return v;
}
154 
/// Computes sin(x)/x accurately, including for x at or near zero.
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
/// Params:
///     x = Input value.
/// Returns: 1 when x * x is too small to change 1 when added to it,
///          sin(x) / x otherwise.
@nogc T sinOverX(T)(T x) pure nothrow
{
    // When x^2 still registers next to 1, the plain quotient is safe to evaluate.
    if (1 + x * x != 1)
        return sin(x) / x;
    return 1;
}
164 
165 
/// Signed modulo a/b where the remainder is guaranteed to be in [0..b[,
/// even if a is negative (the built-in `%` yields a result with the sign of
/// the dividend). Only supports positive dividers.
/// Params:
///     a = Dividend, may be negative.
///     b = Divider, must be strictly positive (checked by the `in` contract).
/// Returns: The remainder of a by b, shifted into [0..b[.
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
in
{
    assert(b > 0);
}
body
{
    // The built-in remainder lies in ]-b..b[ and carries the sign of `a`.
    // The cast undoes integer promotion for types narrower than int.
    T x = cast(T)(a % b);

    // A negative remainder is one period too low: shift it up by b.
    if (x < 0)
        x = cast(T)(x + b);

    assert(x >= 0 && x < b);
    return x;
}
186 
// Checks nextPowerOf2 on both overloads, including inputs that already are
// powers of 2 and a value that only fits the 64-bit overload.
unittest
{
    // int overload
    assert(nextPowerOf2(1) == 1);
    assert(nextPowerOf2(13) == 16);
    assert(nextPowerOf2(16) == 16);
    assert(nextPowerOf2(17) == 32);

    // long overload
    assert(nextPowerOf2(13L) == 16L);
    assert(nextPowerOf2(5_000_000_000L) == 8_589_934_592L);
}
191 
/**
 * Finds the root of a linear polynomial a + b x = 0
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     root = Receives -a / b when b is non-zero. Being an `out` parameter, it
 *            is reset to its initial value on entry.
 * Returns: Number of roots: 0 when b is zero, 1 otherwise.
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    return 0;
}
208 
209 
/**
 * Finds the real roots of a quadratic polynomial a + b x + c x^2 = 0
 *
 * When c is zero the problem is delegated to solveLinear.
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2 (leading coefficient).
 *     outRoots = array of root results, should have room for at least 2 elements.
 * Returns: Number of roots in outRoots: 0 when the discriminant is negative,
 *          2 otherwise (a double root is reported twice).
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);
    if (c == 0)
        return solveLinear(a, b, outRoots[0]);

    T delta = b * b - 4 * a * c;
    if (delta < 0.0 )
        return 0;

    delta = sqrt(delta);

    // x = (-b +/- sqrt(delta)) / (2 c): the divisor is the leading coefficient
    // c, not the constant term a.
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
236 
237 
/**
 * Finds the real roots of a cubic polynomial  a + b x + c x^2 + d x^3 = 0
 *
 * When d is zero the problem is delegated to solveQuadratic.
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     d = Coefficient of x^3 (leading coefficient).
 *     outRoots = array of root results, should have room for at least 3 elements.
 * Returns: Number of roots in outRoots: 3 when all roots are real (repeated
 *          roots are reported once per multiplicity), 1 otherwise.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // adjust coefficients: x^3 + a1 x^2 + a2 x + a3 = 0
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R; // discriminant: >= 0 means three real roots

    if (d2 >= 0)
    {
        // 3 real roots
        if (Q < 0.0)
            return 0;

        if (Qcubed == 0)
        {
            // Q^3 >= R^2 with Q^3 == 0 forces R == 0: triple root at -a1 / 3.
            // Handled apart because R / sqrt(Q^3) would be 0 / 0 here.
            outRoots[0] = outRoots[1] = outRoots[2] = -a1 / 3;
            return 3;
        }

        T P = R / sqrt(Qcubed);

        // Rounding can push P marginally outside acos' domain; clamp it back.
        if (P < -1)
            P = -1;
        else if (P > 1)
            P = 1;

        T theta = acos(P);
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - a1 / 3;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * PI) / 3) - a1 / 3;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * PI) / 3) - a1 / 3;
        return 3;
    }
    else
    {
        // 1 real root. -d2 = R^2 - Q^3 is strictly positive in this branch, so
        // the square root is real and e cannot be zero.
        T e = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            e = -e;
        outRoots[0] = e + Q / e - a1 / 3.0;
        return 1;
    }
}
292 
/**
 * Finds the real roots of a quartic polynomial  a + b x + c x^2 + d x^3 + e x^4 = 0
 *
 * When e is zero the problem is delegated to solveCubic. Otherwise the quartic
 * is split into two quadratic factors using a real root of its resolvent cubic;
 * each factor contributes two roots only when its own discriminant is
 * non-negative.
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     d = Coefficient of x^3.
 *     e = Coefficient of x^4 (leading coefficient).
 *     roots = array of root results, should have room for at least 4 elements.
 * Returns: Number of real roots written at the start of roots (0, 2 or 4 in
 *          the quartic case).
 * Bugs: Previously flagged as not passing its unit test. The method is
 *       numerically fragile near repeated roots, where rounding can turn a
 *       zero discriminant slightly negative; verify before relying on it.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Adjust coefficients: x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Find a root for the following cubic equation:
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1 ^2 - a3^2 a0) = 0
    // aka Resolvent cubic
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);
    if (numRoots < 1)
        return 0;

    // The resolvent may have a single real root, so only the entries actually
    // reported are considered; the largest one is kept.
    T y = resolventCubicRoots[0];
    foreach (k; 1 .. numRoots)
    {
        if (y < resolventCubicRoots[k])
            y = resolventCubicRoots[k];
    }

    // Compute R, D^2 & E^2
    T Rsquare = 0.25f * a3 * a3 - a2 + y;
    if (Rsquare < 0.0)
        return 0;
    T R = sqrt(Rsquare);

    T Dsquare, Esquare;
    if (R == 0)
    {
        T inner = y * y - 4 * a0;
        if (inner < 0)
            return 0; // both factors would have complex coefficients
        T d1 = 0.75f * a3 * a3 - 2 * a2;
        T d2 = 2 * sqrt(inner);
        Dsquare = d1 + d2;
        Esquare = d1 - d2;
    }
    else
    {
        T Rrec = 1 / R;
        T d1 =  0.75f * a3 * a3 - Rsquare - 2 * a2;
        T d2 = 0.25f * Rrec * (4 * a3 * a2 - 8 * a1 - a3 * a3 * a3);
        Dsquare = d1 + d2;
        Esquare = d1 - d2;
    }

    // Compute the real roots: each quadratic factor yields a pair, or nothing
    // when its discriminant is negative (or NaN).
    T center = -0.25f * a3;
    T halfR = 0.5f * R;
    int count = 0;

    if (Dsquare >= 0)
    {
        T D = sqrt(Dsquare) * 0.5f;
        roots[count++] = center + halfR + D;
        roots[count++] = center + halfR - D;
    }
    if (Esquare >= 0)
    {
        T E = sqrt(Esquare) * 0.5f;
        roots[count++] = center - halfR + E;
        roots[count++] = center - halfR - E;
    }
    return count;
}
361 
362 
// Exercises solveCubic and solveQuartic on polynomials with known roots.
unittest
{
    // True when some entry of `found` lies within 1e-7 of `expected`.
    bool hasRoot(double[] found, double expected)
    {
        for (size_t k = 0; k < found.length; ++k)
        {
            if (abs(found[k] - expected) < 1e-7)
                return true;
        }
        return false;
    }

    // test cubic: roots -4, -1 and 2
    {
        double[3] roots;
        int numRoots = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, roots[]);
        assert(numRoots == 3);
        foreach (expected; [-4.0, -1.0, 2.0])
            assert(hasRoot(roots[], expected));
    }

    // test quartic: roots -2, -1, 0 and 1
    {
        double[4] roots;
        int numRoots = solveQuartic!double(0, -2, -1, 2, 1, roots[]);

        assert(numRoots == 4);
        foreach (expected; [-2.0, -1.0, 0.0, 1.0])
            assert(hasRoot(roots[], expected));
    }
}
395 
/// Arithmetic mean of a range.
///
/// The sum is accumulated in a `double` rather than in the element type, so
/// ranges of integers are not subjected to integer division.
/// Params:
///     r = Input range of numeric values.
/// Returns: The mean of the elements, or NaN for an empty range.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    double sum = 0;
    long count = 0;
    foreach(e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
411 
/// Smallest element of a range.
/// Params:
///     r = Input range.
/// Returns: The minimum of the elements, or +infinity for an empty range
///          (like Javascript).
double minElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!("<", R)(r);

    // do like Javascript for an empty range
    return double.infinity;
}
421 
/// Largest element of a range.
/// Params:
///     r = Input range.
/// Returns: The maximum of the elements, or -infinity for an empty range
///          (like Javascript).
double maxElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!(">", R)(r);

    // do like Javascript for an empty range
    return -double.infinity;
}
431 
/// Sample variance of a range (sum of squared deviations divided by count - 1).
/// Params:
///     r = Forward range; a saved copy is traversed once to get the average.
/// Returns: NaN for an empty range, 0.0 for a single element, the sample
///          variance otherwise.
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    // First pass, on a saved copy, to get the average.
    auto mean = average(r.save);

    // Second pass: accumulate squared deviations from the average.
    typeof(mean) squares = 0;
    long count = 0;
    foreach(e; r)
    {
        auto deviation = e - mean;
        squares += deviation * deviation;
        ++count;
    }

    // using sample std deviation as estimator
    return (count > 1) ? (squares / (count - 1.0)) : 0.0;
}
452 
/// Standard deviation of a range: the square root of its variance.
/// Params:
///     r = Forward range.
/// Returns: sqrt(variance(r)).
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    double spread = variance(r);
    return sqrt(spread);
}
458 
private
{
    /// Scans a non-empty range and returns its extreme element.
    ///
    /// The result type comes from `ElementType!R` rather than from calling
    /// `front` on the range type, which is not expected to work for ranges such
    /// as arrays whose primitives are free functions. Qualifiers are stripped
    /// so the running best can be reassigned.
    /// Params:
    ///     op = Comparison operator as a string: "<" selects the minimum,
    ///          ">" the maximum.
    ///     r = Input range; asserted to be non-empty.
    /// Returns: The element e for which `e op best` last held during the scan.
    Unqual!(ElementType!R) minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);
        Unqual!(ElementType!R) best = r.front;
        r.popFront();
        foreach(e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}
473 
/// Reciprocal square root.
///
/// For `float` in builds where the `AsmX86` version is set, this is the SSE
/// approximation given by the `rsqrtss` instruction. In every other case it is
/// computed as 1 / sqrt(x).
/// Params:
///     x = Input value.
/// Returns: An approximation of, or the value of, 1 / sqrt(x).
@nogc T inverseSqrt(T)(T x) pure nothrow if (isFloatingPoint!T)
{
    // Turn the version identifier into a constant usable in `static if`.
    version(AsmX86)
        enum bool haveInlineAsm = true;
    else
        enum bool haveInlineAsm = false;

    static if (haveInlineAsm && is(T == float))
    {
        float result;

        asm pure nothrow @nogc
        {
            movss XMM0, x;
            rsqrtss XMM0, XMM0;
            movss result, XMM0;
        }
        return result;
    }
    else
    {
        return 1 / sqrt(x);
    }
}
497 
// inverseSqrt(1) must be close to 1 for both the float and the double
// instantiations.
unittest
{
    enum tolerance = 1e-3;

    float asFloat = inverseSqrt!float(1);
    assert(abs(asFloat - 1) < tolerance);

    double asDouble = inverseSqrt!double(1);
    assert(abs(asDouble - 1) < tolerance);
}