1 /**
2   Useful math functions and range-based statistic computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9        std.traits,
10        std.range,
11        std.math;
12 
13 import inteli.xmmintrin;
14 
15 import gfm.math.vector : Vector;
16 
/**
 * Converts an angle expressed in radians to degrees.
 * Accepts a floating-point scalar, or a Vector whose component type is
 * floating-point (the vector is multiplied by the scalar factor 180/PI).
 */
@nogc T degrees(T)(in T x) pure nothrow
if (isFloatingPoint!T || (is(T : Vector!(U, n), U, int n) && isFloatingPoint!U))
{
    // The conversion factor is expressed in the scalar (component) type.
    static if (is(T : Vector!(U, n), U, int n))
        alias Scalar = U;
    else
        alias Scalar = T;

    return x * Scalar(180 / PI);
}
26 
/**
 * Converts an angle expressed in degrees to radians.
 * Accepts a floating-point scalar, or a Vector whose component type is
 * floating-point (the vector is multiplied by the scalar factor PI/180).
 */
@nogc T radians(T)(in T x) pure nothrow
if (isFloatingPoint!T || (is(T : Vector!(U, n), U, int n) && isFloatingPoint!U))
{
    // The conversion factor is expressed in the scalar (component) type.
    static if (is(T : Vector!(U, n), U, int n))
        alias Scalar = U;
    else
        alias Scalar = T;

    return x * Scalar(PI / 180);
}
36 
/// Linear interpolation between a (at t = 0) and b (at t = 1), like GLSL's mix().
/// Evaluated as t * b + (1 - t) * a.
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
  if (is(typeof(t * b + (1 - t) * a) : S))
{
    auto towardB = t * b;
    auto towardA = (1 - t) * a;
    return towardB + towardA;
}
43 
/// Restricts x to the closed interval [min, max], like GLSL's clamp().
/// Returns min when x < min, max when x > max, and x otherwise.
/// Because both tests are strict comparisons, an x that compares false
/// against both bounds (e.g. NaN) is returned unchanged.
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    if (x < min)
        return min;
    return (x > max) ? max : x;
}
54 
/// Truncates x toward zero and returns the result as a long.
/// Not marked pure: the original author notes that trunc is not pure.
@nogc long ltrunc(real x) nothrow
{
    immutable real truncated = trunc(x);
    return cast(long) truncated;
}
60 
/// Rounds x toward negative infinity and returns the result as a long.
/// Not marked pure: the original author notes that floor is not pure.
@nogc long lfloor(real x) nothrow
{
    immutable real floored = floor(x);
    return cast(long) floored;
}
66 
/**
 * Returns: Fractional part of x, computed as x - floor(x).
 * For negative x the result is still non-negative, e.g. fract(-1.25) == 0.75.
 *
 * T is deduced from the argument, so `fract(x)` works; the explicit
 * `fract!float(x)` form keeps working as well.
 */
@nogc T fract(T)(T x) nothrow if (isFloatingPoint!T)
{
    // Subtract floor(x) in floating point: converting through a long
    // overflows for |x| > long.max and cannot represent NaN or infinity.
    return x - floor(x);
}
72 
/// Arc sine with its argument first clamped to [-1, 1], so inputs that fall
/// marginally outside the domain (e.g. through rounding) do not yield NaN.
@nogc T safeAsin(T)(T x) pure nothrow
{
    T bounded = clamp!T(x, -1, 1);
    return asin(bounded);
}
78 
/// Arc cosine with its argument first clamped to [-1, 1], so inputs that fall
/// marginally outside the domain (e.g. through rounding) do not yield NaN.
@nogc T safeAcos(T)(T x) pure nothrow
{
    T bounded = clamp!T(x, -1, 1);
    return acos(bounded);
}
84 
/// GLSL step(): returns 0 when x < edge, and 1 otherwise
/// (x == edge gives 1).
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
91 
/// GLSL smoothstep(): 0 for t <= a, 1 for t >= b, and a smooth Hermite
/// transition 3u^2 - 2u^3 in between, where u = (t - a) / (b - a).
/// See: http://en.wikipedia.org/wiki/Smoothstep
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    // Here a < t < b, so the division is well defined.
    immutable T u = (t - a) / (b - a);
    return u * u * (3 - 2 * u);
}
106 
/// Returns: true if i is a power of 2 (1, 2, 4, 8, ...); false for 0.
/// i must be non-negative (checked by assertion).
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;
    // A power of 2 has exactly one bit set, so clearing its lowest set bit leaves 0.
    return (i & (i - 1)) == 0;
}
113 
/// Integer base-2 logarithm of a positive power of 2,
/// i.e. the index of its single set bit (ilog2(8) == 3).
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    import core.bitop : bsr;

    assert(i > 0);
    // Exactly one bit set: i is a power of 2.
    assert((i & (i - 1)) == 0);
    return bsr(i);
}
122 
/**
 * Computes the smallest power of 2 that is greater than or equal to i.
 *
 * Any i <= 1 (including 0 and negative values) yields 1, the smallest power of 2.
 * i must not exceed 2^30, the largest power of 2 representable in an int.
 */
@nogc int nextPowerOf2(int i) pure nothrow
{
    // Without this guard, 0 and negative inputs would come out as 0.
    if (i <= 1)
        return 1;
    assert(i <= (1 << 30), "nextPowerOf2: result does not fit in an int");

    // Copy the highest set bit of (i - 1) into every lower bit position,
    // then add one so the carry lands on the next power of 2.
    int v = i - 1;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v++;
    assert(v > 0 && (v & (v - 1)) == 0);
    return v;
}
136 
/**
 * Computes the smallest power of 2 that is greater than or equal to i.
 *
 * Any i <= 1 (including 0 and negative values) yields 1, the smallest power of 2.
 * i must not exceed 2^62, the largest power of 2 representable in a long.
 */
@nogc long nextPowerOf2(long i) pure nothrow
{
    // Without this guard, 0 and negative inputs would come out as 0.
    if (i <= 1)
        return 1;
    assert(i <= (1L << 62), "nextPowerOf2: result does not fit in a long");

    // Copy the highest set bit of (i - 1) into every lower bit position,
    // then add one so the carry lands on the next power of 2.
    long v = i - 1;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v |= v >> 32;
    v++;
    assert(v > 0 && (v & (v - 1)) == 0);
    return v;
}
151 
/// Computes sin(x)/x, returning exactly 1 when x is so small that
/// 1 + x*x rounds to 1 (this also avoids 0/0 at x == 0).
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
@nogc T sinOverX(T)(T x) pure nothrow
{
    immutable bool negligible = (1 + x * x == 1);
    return negligible ? 1 : sin(x) / x;
}
161 
162 
/// Signed integer modulo a/b where the remainder is guaranteed to be in [0..b[,
/// even if a is negative (e.g. moduloWrap(-1, 5) == 4).
/// Only positive divisors are supported.
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
{
    assert(b > 0);

    // D's % truncates toward zero, so a negative a gives a remainder in ]-b, 0];
    // adding b maps the non-zero ones back into [0, b[.
    T x = a % b;
    if (x < 0)
        x += b;

    assert(x >= 0 && x < b);
    return x;
}
183 
unittest
{
    // nextPowerOf2: both overloads, including values that are already powers of 2.
    assert(nextPowerOf2(13) == 16);
    assert(nextPowerOf2(1) == 1);
    assert(nextPowerOf2(16) == 16);
    assert(nextPowerOf2(17) == 32);
    assert(nextPowerOf2(13L) == 16);

    // isPowerOf2 / ilog2 boundary cases.
    assert(isPowerOf2(1) && isPowerOf2(64));
    assert(!isPowerOf2(0) && !isPowerOf2(6));
    assert(ilog2(1) == 0 && ilog2(1024) == 10);
}
188 
/**
 * Solves the linear equation a + b x = 0.
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     root = Receives -a / b when a root exists (out parameter, so it is
 *            left at T.init when no root is written).
 * Returns: 1 if a root was written, 0 when b == 0.
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    return 0;
}
205 
206 
/**
 * Finds the real roots of the quadratic polynomial a + b x + c x^2 = 0.
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     outRoots = array of root results, should have room for at least 2 elements.
 * Returns: Number of roots written to outRoots (0, 1 or 2).
 *          A double root is reported twice.
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);

    // No x^2 term: fall back to the linear solver.
    if (c == 0)
        return solveLinear(a, b, outRoots[0]);

    T delta = b * b - 4 * a * c;
    if (delta < 0.0)
        return 0;

    delta = sqrt(delta);
    // x = (-b -/+ sqrt(delta)) / (2 c). c is the x^2 coefficient here;
    // a is the constant term.
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
233 
234 
/**
 * Finds the real roots of the cubic polynomial  a + b x + c x^2 + d x^3 = 0
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     d = Coefficient of x^3.
 *     outRoots = array of root results, should have room for at least 3 elements.
 * Returns: Number of roots written to outRoots. When d == 0 this is the
 *          quadratic solver's count.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // Normalize to x^3 + a1 x^2 + a2 x + a3 = 0.
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R;

    if (d2 >= 0)
    {
        // 3 real roots. d2 >= 0 implies Qcubed >= R * R >= 0, so Q >= 0 here.
        if (Qcubed == 0)
        {
            // Q == 0 forces R == 0: a triple root. Handled separately to
            // avoid the 0/0 in R / sqrt(Qcubed) below.
            T triple = -a1 / 3;
            outRoots[0] = triple;
            outRoots[1] = triple;
            outRoots[2] = triple;
            return 3;
        }

        // Rounding can push P marginally outside [-1, 1]; clamp so acos stays defined.
        T P = R / sqrt(Qcubed);
        if (P < -1)
            P = -1;
        else if (P > 1)
            P = 1;

        T theta = acos(P);
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - a1 / 3;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * T(PI)) / 3) - a1 / 3;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * T(PI)) / 3) - a1 / 3;
        return 3;
    }
    else
    {
        // 1 real root. -d2 == R^2 - Q^3 > 0 in this branch; it must be used
        // here, not the cubic coefficient d.
        T e = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            e = -e;
        outRoots[0] = e + Q / e - a1 / 3.0;
        return 1;
    }
}
289 
/**
 * Finds the real roots of the quartic polynomial  a + b x + c x^2 + d x^3 + e x^4 = 0
 * using the resolvent cubic (Ferrari's method, MathWorld notation).
 *
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     d = Coefficient of x^3.
 *     e = Coefficient of x^4.
 *     roots = array of root results, should have room for up to 4 elements.
 * Returns: Number of real roots written to roots (0, 2 or 4 when e != 0;
 *          when e == 0 the cubic solver's count is returned).
 *          Non-real conjugate pairs are skipped rather than reported as NaN.
 * Bugs: A pair whose discriminant rounds to a tiny negative value (near a
 *       repeated root) is dropped. NOTE(review): this function was flagged
 *       as failing its unit-test; re-run it to confirm.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Normalize to x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0.
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Find a root for the following cubic equation:
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1 ^2 - a3^2 a0) = 0
    // aka Resolvent cubic
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);

    // A real cubic has at least one real root, but possibly only one
    // (x^4 - 1 has resolvent y^3 + 4y = 0), so only the first numRoots
    // entries are meaningful.
    assert(numRoots >= 1);
    T y = resolventCubicRoots[0];
    foreach (i; 1 .. numRoots)
        if (y < resolventCubicRoots[i])
            y = resolventCubicRoots[i];

    // Compute R
    T R = 0.25f * a3 * a3 - a2 + y;
    if (R < 0.0)
        return 0;
    R = sqrt(R);

    // D^2 = d1 + d2 and E^2 = d1 - d2.
    T d1 = void,
      d2 = void;
    if (R == 0)
    {
        T disc = y * y - 4 * a0;
        // A negative disc makes d2 imaginary, so D and E (and all four roots) are non-real.
        if (disc < 0)
            return 0;
        d1 = 0.75f * a3 * a3 - 2 * a2;
        d2 = 2 * sqrt(disc);
    }
    else
    {
        T Rrec = 1 / R;
        d1 = 0.75f * a3 * a3 - R * R - 2 * a2;
        d2 = 0.25f * Rrec * (4 * a3 * a2 - 8 * a1 - a3 * a3 * a3);
    }

    // Roots: -a3/4 + R/2 +/- D/2 and -a3/4 - R/2 +/- E/2.
    // Each pair is real only when its squared half-width is non-negative.
    T shift = -0.25f * a3;
    T halfR = 0.5f * R;
    int count = 0;

    T Dsquared = d1 + d2;
    if (Dsquared >= 0)
    {
        T halfD = sqrt(Dsquared) * 0.5f;
        roots[count++] = shift + halfR + halfD;
        roots[count++] = shift + halfR - halfD;
    }

    T Esquared = d1 - d2;
    if (Esquared >= 0)
    {
        T halfE = sqrt(Esquared) * 0.5f;
        roots[count++] = shift - halfR + halfE;
        roots[count++] = shift - halfR - halfE;
    }
    return count;
}
358 
359 
unittest
{
    // True if arr has an element within 1e-7 of root.
    bool arrayContainsRoot(double[] arr, double root)
    {
        foreach(e; arr)
            if (abs(e - root) < 1e-7)
                return true;
        return false;
    }

    // test cubic: -2 - 1.5 x + 0.75 x^2 + 0.25 x^3 = 0.25 (x + 4)(x + 1)(x - 2)
    {
        double[3] roots;
        int numRoots = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, roots[]);
        assert(numRoots == 3);
        assert(arrayContainsRoot(roots[], -4));
        assert(arrayContainsRoot(roots[], -1));
        assert(arrayContainsRoot(roots[], 2));
    }

    // test quartic: -2 x - x^2 + 2 x^3 + x^4 = x (x + 2)(x + 1)(x - 1)
    {
        double[4] roots;
        int numRoots = solveQuartic!double(0, -2, -1, 2, 1, roots[]);

        assert(numRoots == 4);
        assert(arrayContainsRoot(roots[], -2));
        assert(arrayContainsRoot(roots[], -1));
        assert(arrayContainsRoot(roots[], 0));
        assert(arrayContainsRoot(roots[], 1));
    }
}
392 
/// Arithmetic mean of the elements of r.
/// Returns: the mean as a double, or NaN for an empty range.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    // Accumulate in double. Summing in the element type would make the final
    // division an integer division for integral ranges (average([1, 2]) gave 1),
    // and could overflow narrow integer types.
    double sum = 0;
    long count = 0;
    foreach (e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
408 
/// Minimum of a range.
/// Returns: the smallest element as a double, or +infinity for an empty
/// range (the same convention as JavaScript's Math.min()).
double minElement(R)(R r) if (isInputRange!R)
{
    return r.empty ? double.infinity : minmax!("<", R)(r);
}
418 
/// Maximum of a range.
/// Returns: the largest element as a double, or -infinity for an empty
/// range (the same convention as JavaScript's Math.max()).
double maxElement(R)(R r) if (isInputRange!R)
{
    return r.empty ? -double.infinity : minmax!(">", R)(r);
}
428 
/// Sample variance of a range: the sum of squared deviations from the mean
/// divided by (n - 1).
/// Returns: NaN for an empty range, 0 for a single-element range.
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    // Mean is computed on a saved copy so r can be iterated again below.
    auto mean = average(r.save);

    typeof(mean) squaredDeviations = 0;
    long n = 0;
    foreach (x; r)
    {
        squaredDeviations += (x - mean) ^^ 2;
        n++;
    }
    return (n <= 1) ? 0.0 : squaredDeviations / (n - 1.0);
}
449 
/// Sample standard deviation of a range: the square root of variance(r).
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    immutable double var = variance(r);
    return sqrt(var);
}
455 
private
{
    /// Returns the element of a non-empty input range that wins every
    /// comparison `e op best`: op "<" yields the minimum, ">" the maximum.
    /// Comparisons are strict, so on ties the earliest element is kept.
    ElementType!R minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);
        // ElementType!R (instead of typeof(R.front())) also works for
        // built-in arrays, whose `front` is a free function that cannot be
        // called through the type name.
        auto best = r.front;
        r.popFront();
        foreach (e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}
470 
/// Reciprocal square root, 1 / sqrt(x).
/// For float this is an approximation obtained from the intel-intrinsics
/// _mm_rsqrt_ss intrinsic (SSE rsqrtss); every other floating-point type
/// computes 1 / sqrt(x) directly.
@nogc T inverseSqrt(T)(T x) pure nothrow if (isFloatingPoint!T)
{
    static if (!is(T == float))
    {
        return 1 / sqrt(x);
    }
    else
    {
        return _mm_cvtss_f32(_mm_rsqrt_ss(_mm_set_ss(x)));
    }
}
485 
unittest
{
    // Both the float (intrinsic) path and the generic path should be close to 1 at x = 1.
    immutable float viaIntrinsic = inverseSqrt!float(1);
    immutable double viaSqrt = inverseSqrt!double(1);
    assert(abs(viaIntrinsic - 1) < 1e-3);
    assert(abs(viaSqrt - 1) < 1e-3);
}