1 /**
2   Useful math functions and range-based statistic computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9        std.traits,
10        std.range,
11        std.math;
12 
// Compilers older than 2.066 have no built-in @nogc attribute: declaring a symbol
// named `nogc` lets every `@nogc` in this module still compile there (as a
// user-defined attribute).
static if( __VERSION__ < 2066 ) private enum nogc = 1;

// AsmX86 is set whenever D inline assembly is available for x86 or x86-64;
// inverseSqrt checks it to choose its inline-asm code path.
version( D_InlineAsm_X86 )
{
    version = AsmX86;
}
else version( D_InlineAsm_X86_64 )
{
    version = AsmX86;
}
23 
/// Returns: the smaller of a and b; b is returned when `a < b` is false
/// (equal or unordered values).
deprecated("Use std.algorithm.min instead") @nogc T min(T)(T a, T b) pure nothrow
{
    if (a < b)
        return a;
    return b;
}
29 
/// Returns: the larger of a and b; b is returned when `a > b` is false
/// (equal or unordered values).
deprecated("Use std.algorithm.max instead") @nogc T max(T)(T a, T b) pure nothrow
{
    if (a > b)
        return a;
    return b;
}
35 
/// Converts an angle expressed in radians to degrees.
@nogc T degrees(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum radiansToDegrees = 180 / PI; // folded at compile time
    return x * radiansToDegrees;
}
41 
/// Converts an angle expressed in degrees to radians.
@nogc T radians(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum degreesToRadians = PI / 180; // folded at compile time
    return x * degreesToRadians;
}
47 
/// Linear interpolation between a (at t = 0) and b (at t = 1), akin to GLSL's mix.
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
{
    return (1 - t) * a + t * b;
}
53 
/// Clamps x into [min, max], akin to GLSL's clamp.
/// x is returned unchanged when it is neither below min nor above max
/// (this includes unordered values such as NaN).
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    return (x < min) ? min : ((x > max) ? max : x);
}
64 
/// Truncates x toward zero and returns the result as a long.
@nogc long ltrunc(real x) nothrow // could be pure, but trunc isn't pure
{
    immutable real truncated = trunc(x);
    return cast(long) truncated;
}
70 
/// Rounds x toward negative infinity and returns the result as a long.
@nogc long lfloor(real x) nothrow // could be pure, but floor isn't pure
{
    immutable real floored = floor(x);
    return cast(long) floored;
}
76 
/// Returns: Fractional part of x, computed as x - floor(x)
/// (so it is non-negative for negative inputs too, e.g. fract(-0.75) == 0.25).
@nogc T fract(T)(T x) nothrow if (isFloatingPoint!T)
{
    // Taking the parameter as T lets T be inferred (fract(1.5) now compiles).
    // Using floor() keeps the computation in floating point; converting through
    // a long (as lfloor does) would overflow for |x| >= 2^63.
    return x - floor(x);
}
82 
/// Arc-sine whose argument is first clamped to [-1, 1], so values slightly
/// outside the domain (e.g. from rounding) map to asin(-1) or asin(1).
@nogc T safeAsin(T)(T x) pure nothrow
{
    T clamped = x;
    if (clamped < -1)
        clamped = -1;
    else if (clamped > 1)
        clamped = 1;
    return asin(clamped);
}
88 
/// Arc-cosine whose argument is first clamped to [-1, 1], so values slightly
/// outside the domain (e.g. from rounding) map to acos(-1) or acos(1).
@nogc T safeAcos(T)(T x) pure nothrow
{
    T clamped = x;
    if (clamped < -1)
        clamped = -1;
    else if (clamped > 1)
        clamped = 1;
    return acos(clamped);
}
94 
/// Same as the GLSL step function:
/// returns 0 when x < edge, and 1 otherwise.
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
101 
/// Same as the GLSL smoothstep function: 0 when t <= a, 1 when t >= b,
/// and a smooth Hermite curve in between.
/// See: http://en.wikipedia.org/wiki/Smoothstep
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    // Remap t from [a, b] to [0, 1], then evaluate 3u^2 - 2u^3.
    T u = (t - a) / (b - a);
    return u * u * (3 - 2 * u);
}
116 
/// Returns: true if i is a power of 2. i must be non-negative (asserted).
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;
    // A power of 2 has a single bit set, so clearing its lowest set bit leaves 0.
    return (i & (i - 1)) == 0;
}
123 
/// Integer base-2 logarithm of i, which must be a positive power of 2 (asserted).
/// TODO: use bt intrinsics
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    assert(i > 0);
    assert(isPowerOf2(i));
    int exponent = 0;
    for (; i > 1; i /= 2)
        ++exponent;
    return exponent;
}
138 
/// Returns: the smallest power of 2 that is >= i.
/// When assertions are enabled, an assertion fails if the result is not a
/// power of 2, which happens for i <= 0 and for i > 2^30.
@nogc int nextPowerOf2(int i) pure nothrow
{
    // Copy the highest set bit of (i - 1) into every lower bit, then add one.
    int v = i - 1;
    for (int shift = 1; shift < 32; shift <<= 1)
        v |= v >> shift;
    ++v;
    assert(isPowerOf2(v));
    return v;
}
152 
/// Returns: the smallest power of 2 that is >= i.
/// When assertions are enabled, an assertion fails if the result is not a
/// power of 2, which happens for i <= 0 and for i > 2^62.
@nogc long nextPowerOf2(long i) pure nothrow
{
    // Copy the highest set bit of (i - 1) into every lower bit, then add one.
    long v = i - 1;
    for (int shift = 1; shift < 64; shift <<= 1)
        v |= v >> shift;
    ++v;
    assert(isPowerOf2(v));
    return v;
}
167 
/// Computes sin(x)/x accurately: when x is so small that 1 + x^2 rounds to 1,
/// exactly 1 is returned (this also avoids dividing by zero at x = 0).
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
@nogc T sinOverX(T)(T x) pure nothrow
{
    immutable bool negligible = (1 + x * x == 1);
    return negligible ? 1 : sin(x) / x;
}
177 
178 
/// Signed modulo a % b whose result is guaranteed to be in [0, b), even when
/// a is negative (e.g. moduloWrap(-1, 3) == 2). Only positive divisors are supported.
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
in
{
    assert(b > 0);
}
body
{
    // D's % truncates toward zero, so the remainder has the sign of a and
    // lies in (-b, b); shifting negative remainders by b lands them in (0, b).
    T x = a % b;
    if (x < 0)
        x += b;

    assert(x >= 0 && x < b);
    return x;
}
199 
// nextPowerOf2 rounds a non-power of 2 up to the next power of 2.
unittest
{
    assert(nextPowerOf2(13) == 16);
}
204 
/**
 * Finds the root of the linear polynomial a + b x = 0.
 * Params:
 *     root = receives -a / b when b is non-zero; being an `out` parameter,
 *            it is reset to T.init on entry in every case.
 * Returns: Number of roots written to root: 1, or 0 when b == 0
 *          (even if a == 0 as well).
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    return 0;
}
221 
222 
/**
 * Finds the real roots of the quadratic polynomial a + b x + c x^2 = 0.
 * When c == 0 the equation is solved as a linear one instead.
 * Params:
 *     outRoots = array of root results, must have room for at least 2 elements.
 * Returns: Number of roots written to outRoots (0, 1 or 2). A double root is
 *          reported twice.
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);
    if (c == 0)
        return solveLinear(a, b, outRoots[0]);

    T delta = b * b - 4 * a * c;
    if (delta < 0)
        return 0; // only a complex-conjugate pair

    delta = sqrt(delta);

    // With this coefficient order the leading coefficient is c (not a),
    // so the roots are (-b -/+ sqrt(delta)) / (2 c).
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
246 
247 
/**
 * Finds the real roots of the cubic polynomial a + b x + c x^2 + d x^3 = 0.
 * When d == 0 the equation is solved as a quadratic one instead.
 * Params:
 *     outRoots = array of root results, must have room for at least 3 elements.
 * Returns: Number of roots written to outRoots. Repeated roots may be reported
 *          more than once.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // Normalize to the monic form x^3 + a1 x^2 + a2 x + a3 = 0.
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R; // >= 0: three real roots, < 0: one real root

    if (d2 >= 0)
    {
        // 3 real roots (possibly repeated). Here Qcubed >= R^2 >= 0, so
        // Qcubed == 0 means R^2 == 0 too: a triple root at -a1 / 3.
        // (Dividing by sqrt(Qcubed) below would otherwise give 0/0.)
        if (Qcubed == 0)
        {
            outRoots[0] = outRoots[1] = outRoots[2] = -a1 / 3;
            return 3;
        }

        // Mathematically |R| / sqrt(Q^3) <= 1, but rounding can push it
        // slightly outside [-1, 1]; safeAcos clamps it instead of yielding NaN.
        T theta = safeAcos(R / sqrt(Qcubed));
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - a1 / 3;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * PI) / 3) - a1 / 3;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * PI) / 3) - a1 / 3;
        return 3;
    }
    else
    {
        // 1 real root. R^2 - Q^3 == -d2 > 0 (the discriminant, not the
        // leading coefficient d).
        T e = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            e = -e;
        outRoots[0] = e + Q / e - a1 / 3;
        return 1;
    }
}
298 
/**
 * Finds the real roots of the quartic polynomial a + b x + c x^2 + d x^3 + e x^4 = 0
 * with Ferrari's method (via the resolvent cubic).
 * When e == 0 the equation is solved as a cubic one instead.
 * Params:
 *     roots = array of root results, must have room for at least 4 elements.
 * Returns: Number of real roots written to roots (0 to 4). Repeated roots may be
 *          reported more than once.
 * Note: some steps compare computed values against zero exactly, so roots of
 *       even multiplicity can be missed because of rounding.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Normalize to the monic form x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0.
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Resolvent cubic:
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1 ^2 - a3^2 a0) = 0
    // Being monic of odd degree it always has a real root, but not necessarily
    // three, so any number of returned roots is accepted; the largest is used.
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numResolventRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);
    if (numResolventRoots <= 0)
        return 0;
    T y = resolventCubicRoots[0];
    foreach (candidate; resolventCubicRoots[1 .. numResolventRoots])
        if (candidate > y)
            y = candidate;

    // The resolvent equals R^2 (y^2 - 4 a0) - (a3 y / 2 - a1)^2 with
    // R^2 = a3^2 / 4 - a2 + y, so it is <= 0 where R^2 == 0. Its largest root
    // therefore gives R^2 >= 0; a negative value can only be rounding noise.
    T Rsquare = 0.25f * a3 * a3 - a2 + y;
    if (isNaN(Rsquare))
        return 0;
    if (Rsquare < 0)
        Rsquare = 0;
    T R = sqrt(Rsquare);

    // At a resolvent root, (a3 y / 2 - a1)^2 == R^2 (y^2 - 4 a0). This turns the
    // usual term (4 a3 a2 - 8 a1 - a3^3) / (4 R) into
    // 2 sign(a3 y / 2 - a1) sqrt(y^2 - 4 a0) - a3 R, which needs no division by R
    // and also covers the R == 0 case.
    T disc = y * y - 4 * a0;
    if (disc < 0)
    {
        if (R == 0)
            return 0; // x^2 + a3 x / 2 + y / 2 would have to be imaginary: no real roots
        disc = 0;     // the identity above forces disc >= 0 when R > 0
    }
    T signedSqrtDisc = sqrt(disc);
    if (0.5f * a3 * y - a1 < 0)
        signedSqrtDisc = -signedSqrtDisc;

    T d1 = 0.75f * a3 * a3 - Rsquare - 2 * a2;
    T d2 = 2 * signedSqrtDisc - a3 * R;

    // Roots: -a3/4 + R/2 +/- D/2 and -a3/4 - R/2 +/- E/2, where D^2 = d1 + d2 and
    // E^2 = d1 - d2. A negative square means that pair is complex and is skipped.
    T center = -0.25f * a3;
    T halfR = 0.5f * R;
    int count = 0;

    T Dsquare = d1 + d2;
    if (Dsquare >= 0)
    {
        T halfD = 0.5f * sqrt(Dsquare);
        roots[count++] = center + halfR + halfD;
        roots[count++] = center + halfR - halfD;
    }

    T Esquare = d1 - d2;
    if (Esquare >= 0)
    {
        T halfE = 0.5f * sqrt(Esquare);
        roots[count++] = center - halfR + halfE;
        roots[count++] = center - halfR - halfE;
    }
    return count;
}
367 
368 
// Checks the cubic and quartic solvers against polynomials with known real roots.
unittest
{
    // true when some element of arr is within 1e-7 of root
    bool arrayContainsRoot(double[] arr, double root)
    {
        foreach(e; arr)
            if (abs(e - root) < 1e-7)
                return true;
        return false;
    }

    // test cubic: -2 - 1.5 x + 0.75 x^2 + 0.25 x^3 = 0, whose roots are -4, -1 and 2
    {
        double[3] roots;
        int numRoots = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, roots[]);
        assert(numRoots == 3);
        assert(arrayContainsRoot(roots[], -4));
        assert(arrayContainsRoot(roots[], -1));
        assert(arrayContainsRoot(roots[], 2));
    }

    // test quartic: -2 x - x^2 + 2 x^3 + x^4 = x (x + 2)(x + 1)(x - 1) = 0
    {
        double[4] roots;
        int numRoots = solveQuartic!double(0, -2, -1, 2, 1, roots[]);

        assert(numRoots == 4);
        assert(arrayContainsRoot(roots[], -2));
        assert(arrayContainsRoot(roots[], -1));
        assert(arrayContainsRoot(roots[], 0));
        assert(arrayContainsRoot(roots[], 1));
    }
}
401 
/// Arithmetic mean of the elements of r.
/// Returns: double.nan for an empty range.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    // Accumulate in floating point. Summing in the element type made the final
    // division an integer (truncating) division for integral ranges, e.g.
    // average([1, 2]) == 1, and could overflow.
    real sum = 0;
    long count = 0;
    foreach (e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
417 
/// Smallest element of a range, as a double.
/// An empty range yields double.infinity, like JavaScript's Math.min() with no arguments.
double minElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!("<", R)(r);
    return double.infinity;
}
427 
/// Largest element of a range, as a double.
/// An empty range yields -double.infinity, like JavaScript's Math.max() with no arguments.
double maxElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!(">", R)(r);
    return -double.infinity;
}
437 
/// Sample variance of a range: the sum of squared deviations from the mean,
/// divided by (n - 1).
/// Returns: double.nan for an empty range, 0 for a single element.
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    immutable mean = average(r.save); // r.save keeps r usable for the second pass

    double sumOfSquares = 0;
    long n = 0;
    foreach (value; r)
    {
        sumOfSquares += (value - mean) ^^ 2;
        n++;
    }
    return (n > 1) ? sumOfSquares / (n - 1.0) : 0.0;
}
458 
/// Standard deviation of a range: the square root of its sample variance.
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    immutable double v = variance(r);
    return sqrt(v);
}
464 
private
{
    /// Returns the first element of the non-empty range r that no later element
    /// beats under `op`: the first minimum for "<", the first maximum for ">".
    // The return type used to be typeof(R.front()), which does not compile for
    // arrays (front is a UFCS function and cannot be looked up on a type);
    // Unqual also lets `best` be reassigned for ranges of const elements.
    Unqual!(ElementType!R) minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);
        Unqual!(ElementType!R) best = r.front;
        r.popFront();
        foreach (e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}
479 
/// Reciprocal square root, 1 / sqrt(x).
/// For float on targets with x86 inline assembly (the AsmX86 version), the SSE
/// `rsqrtss` instruction is used, which only approximates the result; every other
/// case computes 1 / sqrt(x).
@nogc T inverseSqrt(T)(T x) pure nothrow if (isFloatingPoint!T)
{
    version (AsmX86)
    {
        enum bool useRsqrtss = is(T == float);
    }
    else
    {
        enum bool useRsqrtss = false;
    }

    static if (useRsqrtss)
    {
        float result;

        // Older compilers do not accept attributes on asm blocks.
        static if( __VERSION__ >= 2067 )
            mixin(`asm pure nothrow @nogc { movss XMM0, x; rsqrtss XMM0, XMM0; movss result, XMM0; }`);
        else
            mixin(`asm { movss XMM0, x; rsqrtss XMM0, XMM0; movss result, XMM0; }`);
        return result;
    }
    else
    {
        return 1 / sqrt(x);
    }
}
501 
// Sanity check at x = 1. The tolerance is loose because the float path may use
// the approximate SSE rsqrtss instruction.
unittest
{
    assert(abs( inverseSqrt!float(1) - 1) < 1e-3 );
    assert(abs( inverseSqrt!double(1) - 1) < 1e-3 );
}