1 /**
  Useful math functions and range-based statistical computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9        std.traits,
10        std.range,
11        std.math;
12 
// AsmX86 is set whenever DMD-style x86 inline assembly is available,
// in either its 32-bit (D_InlineAsm_X86) or 64-bit (D_InlineAsm_X86_64) flavour.
version (D_InlineAsm_X86)
    version = AsmX86;
else version (D_InlineAsm_X86_64)
    version = AsmX86;
21 
/// Converts an angle expressed in radians into degrees.
@nogc T degrees(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum radToDeg = 180 / PI; // folded at compile time (real precision)
    return x * radToDeg;
}
27 
/// Converts an angle expressed in degrees into radians.
@nogc T radians(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum degToRad = PI / 180; // folded at compile time (real precision)
    return x * degToRad;
}
33 
/// Linear interpolation between a and b, akin to GLSL's mix.
/// Computes t * b + (1 - t) * a, so for finite inputs t == 0 yields a and t == 1 yields b.
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
{
    auto weightedB = t * b;
    auto weightedA = (1 - t) * a;
    return weightedB + weightedA;
}
39 
/// Clamps x to the interval [min, max], akin to GLSL's clamp.
/// Returns: min if x < min, max if x > max, x otherwise (including when x is NaN).
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    return (x < min) ? min : ((x > max) ? max : x);
}
50 
/// Integer truncation: rounds x toward zero and converts the result to long.
/// Values whose truncation does not fit in a long are not handled.
@nogc long ltrunc(real x) nothrow // not marked pure: trunc was not pure in the targeted Phobos
{
    immutable real truncated = trunc(x);
    return cast(long) truncated;
}
56 
/// Integer flooring: rounds x toward negative infinity and converts the result to long.
/// Values whose floor does not fit in a long are not handled.
@nogc long lfloor(real x) nothrow // not marked pure: floor was not pure in the targeted Phobos
{
    immutable real floored = floor(x);
    return cast(long) floored;
}
62 
/// Fractional part of x, computed as x - floor(x) like GLSL's fract.
/// For negative x the result is still non-negative (e.g. fract(-0.25) == 0.75).
/// Returns: Fractional part of x.
@nogc T fract(T)(T x) nothrow if (isFloatingPoint!T)
{
    // The parameter is typed T (it used to be `real`) so that T can be inferred from
    // the argument. floor() stays in floating point, whereas the former lfloor()
    // went through a cast to long that overflows for |x| >= 2^63.
    return x - floor(x);
}
68 
/// Safe asin: x is first clamped into [-1, 1], so inputs slightly outside that
/// domain (e.g. from rounding) map to asin(-1) or asin(1) instead of NaN.
@nogc T safeAsin(T)(T x) pure nothrow
{
    T clamped = clamp!T(x, -1, 1);
    return asin(clamped);
}
74 
/// Safe acos: x is first clamped into [-1, 1], so inputs slightly outside that
/// domain (e.g. from rounding) map to acos(-1) or acos(1) instead of NaN.
@nogc T safeAcos(T)(T x) pure nothrow
{
    T clamped = clamp!T(x, -1, 1);
    return acos(clamped);
}
80 
/// Same as GLSL's step function.
/// Returns: 0 when x < edge, 1 otherwise (a NaN comparison therefore yields 1).
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
87 
/// Same as GLSL's smoothstep function: 0 when t <= a, 1 when t >= b, and in between
/// the cubic u*u*(3 - 2u) with u = (t - a) / (b - a).
/// See: http://en.wikipedia.org/wiki/Smoothstep
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    // Normalized position of t inside ]a, b[.
    T u = (t - a) / (b - a);
    return u * u * (3 - 2 * u);
}
102 
/// Returns: true if i is a power of 2 (1, 2, 4, ...); 0 is not a power of 2.
/// i must be non-negative (checked by an assertion).
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;
    // A power of two has a single bit set, so clearing its lowest set bit leaves 0.
    return (i & (i - 1)) == 0;
}
109 
/// Integer base-2 logarithm of a power of two.
/// Params: i = a strictly positive power of two (both conditions are asserted).
/// Returns: the exponent k such that 2^k == i.
/// TODO: use bt intrinsics
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    assert(i > 0);
    assert(isPowerOf2(i));
    int exponent;
    // Halve i until it reaches 1, counting the halvings.
    for (; i > 1; i >>= 1)
        ++exponent;
    return exponent;
}
124 
/// Rounds i up to the nearest power of 2 (32-bit version).
/// Params: i = expected to be in [1, 2^30] so the result fits in an int; other inputs
///             fail the internal power-of-two assertion.
/// Returns: the smallest power of 2 that is >= i.
@nogc int nextPowerOf2(int i) pure nothrow
{
    // Copy the highest set bit of (i - 1) into every lower bit, then add one
    // so the carry lands on the next power of two.
    int bits = i - 1;
    for (int shift = 1; shift < 32; shift <<= 1)
        bits |= bits >> shift;
    ++bits;
    assert(isPowerOf2(bits));
    return bits;
}
138 
/// Rounds i up to the nearest power of 2 (64-bit version).
/// Params: i = expected to be in [1, 2^62] so the result fits in a long; other inputs
///             fail the internal power-of-two assertion.
/// Returns: the smallest power of 2 that is >= i.
@nogc long nextPowerOf2(long i) pure nothrow
{
    // Copy the highest set bit of (i - 1) into every lower bit, then add one
    // so the carry lands on the next power of two.
    long bits = i - 1;
    for (int shift = 1; shift < 64; shift <<= 1)
        bits |= bits >> shift;
    ++bits;
    assert(isPowerOf2(bits));
    return bits;
}
153 
/// Computes sin(x)/x accurately, including near 0.
/// When x*x is negligible next to 1, the limit value 1 is returned instead of dividing.
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
@nogc T sinOverX(T)(T x) pure nothrow
{
    return (1 + x * x == 1) ? 1 : sin(x) / x;
}
163 
164 
/// Signed modulo a/b whose result is always in [0..b[, even when a is negative
/// (the built-in `%` takes the sign of a). Only positive divisors are supported.
/// Examples: moduloWrap(7, 5) == 2, moduloWrap(-1, 5) == 4, moduloWrap(-5, 5) == 0.
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
in
{
    assert(b > 0);
}
body
{
    // Built-in remainder: lies in ]-b, b[ and has the sign of a.
    T rem = cast(T)(a % b);
    if (rem < 0)
    {
        // Shift a negative remainder into ]0, b[ (rem + b, not -rem + b).
        rem += b;
        // For floating-point T the addition can round up to exactly b.
        if (rem >= b)
            rem = 0;
    }

    static if (isIntegral!T)
        assert(rem >= 0 && rem < b);
    return rem;
}
185 
unittest
{
    // nextPowerOf2 rounds up and leaves exact powers of two unchanged (both overloads).
    assert(nextPowerOf2(13) == 16);
    assert(nextPowerOf2(16) == 16);
    assert(nextPowerOf2(1) == 1);
    assert(nextPowerOf2(13L) == 16L);
    assert(nextPowerOf2(3L << 40) == 1L << 42);

    // isPowerOf2 / ilog2 on a few representative values.
    assert(isPowerOf2(64));
    assert(!isPowerOf2(0));
    assert(!isPowerOf2(12));
    assert(ilog2(1) == 0);
    assert(ilog2(1024) == 10);
}
190 
/**
 * Solves the linear equation a + b x = 0.
 * Params:
 *     a = constant term.
 *     b = coefficient of x.
 *     root = receives -a / b when b != 0; being an `out` parameter it is reset to
 *            T.init otherwise.
 * Returns: Number of roots (1, or 0 when b == 0).
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    // b == 0: no unique solution.
    return 0;
}
207 
208 
/**
 * Finds the real roots of the quadratic polynomial a + b x + c x^2 = 0.
 * Params:
 *     a = constant term.
 *     b = coefficient of x.
 *     c = coefficient of x^2.
 *     outRoots = array of root results, should have room for at least 2 elements.
 * Returns: Number of roots in outRoots (0, 1 or 2). A double root is reported twice.
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);
    if (c == 0)
        return solveLinear(a, b, outRoots[0]); // degenerates to a + b x = 0

    T delta = b * b - 4 * a * c;
    if (delta < 0.0)
        return 0; // only complex roots

    delta = sqrt(delta);
    // The leading (x^2) coefficient is c, so the roots are (-b +/- sqrt(delta)) / (2c).
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
232 
233 
/**
 * Finds the real roots of the cubic polynomial a + b x + c x^2 + d x^3 = 0.
 * Params:
 *     outRoots = array of root results, should have room for at least 3 elements.
 * Returns: Number of roots in outRoots: 3 when all roots are real (repeated roots may
 *          appear more than once), 1 otherwise; falls back to solveQuadratic when d == 0.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // Normalize to the monic form x^3 + a1 x^2 + a2 x + a3 = 0.
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R; // d2 >= 0 <=> three real roots
    T shift = a1 / 3;

    if (d2 >= 0)
    {
        // 3 real roots.
        if (Qcubed <= 0)
        {
            // Q^3 >= R^2 >= 0 together with Q^3 == 0 forces R == 0:
            // a triple root at -a1/3 (the general formula would compute 0/0).
            outRoots[0] = -shift;
            outRoots[1] = -shift;
            outRoots[2] = -shift;
            return 3;
        }

        T P = R / sqrt(Qcubed);
        // Rounding can push P just outside [-1, 1], where acos returns NaN.
        if (P < -1)
            P = -1;
        else if (P > 1)
            P = 1;

        T theta = acos(P);
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - shift;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * PI) / 3) - shift;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * PI) / 3) - shift;
        return 3;
    }
    else
    {
        // 1 real root: A = -sign(R) * (|R| + sqrt(R^2 - Q^3))^(1/3), where R^2 - Q^3 == -d2.
        T A = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            A = -A;
        T B = (A == 0) ? 0 : Q / A;
        outRoots[0] = A + B - shift;
        return 1;
    }
}
284 
/**
 * Finds the real roots of the quartic polynomial a + b x + c x^2 + d x^3 + e x^4 = 0,
 * using a real root of the resolvent cubic.
 *
 * roots slice should have room for up to 4 elements.
 * Returns: Number of real roots written at the front of roots (0, 2 or 4 for a true
 *          quartic; falls back to solveCubic when e == 0). Pairs of complex roots are
 *          skipped instead of being reported as NaN.
 * NOTE(review): this function was previously documented as failing its unit test;
 *               re-run the test to confirm it now passes.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Normalize to the monic form x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0.
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Find a root for the following cubic equation (aka resolvent cubic):
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1^2 - a3^2 a0) = 0
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numResolventRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);

    // The resolvent may have one or three real roots; keep the largest one found.
    if (numResolventRoots < 1)
        return 0;
    T y = resolventCubicRoots[0];
    foreach (i; 1 .. numResolventRoots)
    {
        if (y < resolventCubicRoots[i])
            y = resolventCubicRoots[i];
    }

    // Compute R, then the squared arguments of D and E.
    T R = 0.25f * a3 * a3 - a2 + y;
    if (R < 0.0)
        return 0;
    R = sqrt(R);

    T Dsquared = void,
      Esquared = void;
    if (R == 0)
    {
        T d1 = 0.75f * a3 * a3 - 2 * a2;
        T d2 = 2 * sqrt(y * y - 4 * a0);
        Dsquared = d1 + d2;
        Esquared = d1 - d2;
    }
    else
    {
        T Rsquare = R * R;
        T Rrec = 1 / R;
        T d1 = 0.75f * a3 * a3 - Rsquare - 2 * a2;
        T d2 = 0.25f * Rrec * (4 * a3 * a2 - 8 * a1 - a3 * a3 * a3);
        Dsquared = d1 + d2;
        Esquared = d1 - d2;
    }

    // Each pair of roots is real only if its square-root argument is non-negative.
    // A NaN argument (e.g. from sqrt of a negative y^2 - 4 a0) fails the test and is skipped.
    T shift = -0.25f * a3;
    T halfR = 0.5f * R;
    int numRoots = 0;
    if (Dsquared >= 0)
    {
        T halfD = 0.5f * sqrt(Dsquared);
        roots[numRoots++] = shift + halfR + halfD;
        roots[numRoots++] = shift + halfR - halfD;
    }
    if (Esquared >= 0)
    {
        T halfE = 0.5f * sqrt(Esquared);
        roots[numRoots++] = shift - halfR + halfE;
        roots[numRoots++] = shift - halfR - halfE;
    }
    return numRoots;
}
353 
354 
unittest
{
    // True when some element of arr lies within 1e-7 of root.
    static bool arrayContainsRoot(double[] arr, double root)
    {
        foreach (candidate; arr)
        {
            if (abs(candidate - root) < 1e-7)
                return true;
        }
        return false;
    }

    // test cubic: (x + 4)(x + 1)(x - 2) / 4 = -2 - 1.5 x + 0.75 x^2 + 0.25 x^3
    {
        double[3] cubicRoots;
        immutable int count = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, cubicRoots[]);
        assert(count == 3);
        foreach (expected; [-4.0, -1.0, 2.0])
            assert(arrayContainsRoot(cubicRoots[], expected));
    }

    // test quartic: x (x + 2)(x + 1)(x - 1) = -2 x - x^2 + 2 x^3 + x^4
    {
        double[4] quarticRoots;
        immutable int count = solveQuartic!double(0, -2, -1, 2, 1, quarticRoots[]);
        assert(count == 4);
        foreach (expected; [-2.0, -1.0, 0.0, 1.0])
            assert(arrayContainsRoot(quarticRoots[], expected));
    }
}
387 
/// Arithmetic mean of the elements of a range.
/// Returns: the mean as a double, or NaN for an empty range.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    alias E = Unqual!(ElementType!R);
    // Integral elements are accumulated in floating point: summing in the element type
    // could overflow, and `sum / count` would be a truncating integer division.
    // Unqual lets ranges of const elements have a mutable accumulator.
    static if (isIntegral!E)
        double sum = 0;
    else
        E sum = 0;

    long count = 0;
    foreach (e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
403 
/// Minimum of a range.
/// Returns: the smallest element as a double, or +infinity for an empty range
/// (the same convention as JavaScript's Math.min()).
double minElement(R)(R r) if (isInputRange!R)
{
    return r.empty ? double.infinity : minmax!("<", R)(r);
}
413 
/// Maximum of a range.
/// Returns: the largest element as a double, or -infinity for an empty range
/// (the same convention as JavaScript's Math.max()).
double maxElement(R)(R r) if (isInputRange!R)
{
    return r.empty ? -double.infinity : minmax!(">", R)(r);
}
423 
/// Sample variance of a range: the sum of squared deviations from the mean,
/// divided by n - 1 (Bessel's correction).
/// Returns: NaN for an empty range, 0 for a single element.
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    // First pass, over a saved copy: the mean.
    auto mean = average(r.save);

    // Second pass: accumulate squared deviations from the mean.
    typeof(mean) sumSquares = 0;
    long n = 0;
    foreach (e; r)
    {
        sumSquares += (e - mean) ^^ 2;
        ++n;
    }

    return (n > 1) ? sumSquares / (n - 1.0) : 0.0;
}
444 
/// Standard deviation of a range: the square root of its sample variance.
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    immutable double v = variance(r);
    return sqrt(v);
}
450 
private
{
    /// Returns: the best element of the non-empty range r under the comparison `op`:
    /// the first minimum for op "<", the first maximum for op ">".
    Unqual!(ElementType!R) minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);
        // ElementType/Unqual replace `typeof(R.front())`, which does not resolve for
        // built-in arrays (their `front` is a free function), and which gave a
        // non-assignable accumulator for ranges of const elements.
        Unqual!(ElementType!R) best = r.front;
        r.popFront();
        foreach (e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}
465 
/// Reciprocal square root, 1 / sqrt(x).
/// When DMD-style x86 inline assembly is available (version AsmX86) and T is float,
/// the SSE `rsqrtss` instruction is used, which only approximates the result.
/// All other cases compute 1 / sqrt(x) with std.math.sqrt.
@nogc T inverseSqrt(T)(T x) pure nothrow if (isFloatingPoint!T)
{
    version (AsmX86)
        enum bool useRsqrtss = is(T == float);
    else
        enum bool useRsqrtss = false;

    static if (useRsqrtss)
    {
        float result;

        // Front-ends 2.067+ accept attributes on asm blocks; they are spelled out so the
        // asm does not conflict with the function's own pure nothrow @nogc attributes.
        static if( __VERSION__ >= 2067 )
            mixin(`asm pure nothrow @nogc { movss XMM0, x; rsqrtss XMM0, XMM0; movss result, XMM0; }`);
        else
            mixin(`asm { movss XMM0, x; rsqrtss XMM0, XMM0; movss result, XMM0; }`);
        return result;
    }
    else
    {
        return 1 / sqrt(x);
    }
}
487 
unittest
{
    // Both the float path (possibly approximated by rsqrtss) and the double path
    // should give approximately 1 for an input of 1.
    immutable float viaFloat = inverseSqrt!float(1);
    immutable double viaDouble = inverseSqrt!double(1);
    assert(abs(viaFloat - 1) < 1e-3);
    assert(abs(viaDouble - 1) < 1e-3);
}