1 /**
2   Useful math functions and range-based statistic computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9        std.traits,
10        std.range,
11        std.math;
12 
13 import inteli.xmmintrin;
14 
15 import gfm.math.vector : Vector;
16 
/// Converts an angle expressed in radians into degrees.
/// Works on a floating-point scalar or on a floating-point `Vector`, which is
/// scaled component-wise by 180/π.
@nogc T degrees(T)(in T x) pure nothrow
if (isFloatingPoint!T || (is(T : Vector!(U, n), U, int n) && isFloatingPoint!U))
{
    // Pick the scalar type the conversion factor must be expressed in.
    static if (is(T : Vector!(U, n), U, int n))
        alias Scalar = U;
    else
        alias Scalar = T;

    enum Scalar radToDeg = 180 / PI;
    return x * radToDeg;
}
26 
/// Converts an angle expressed in degrees into radians.
/// Works on a floating-point scalar or on a floating-point `Vector`, which is
/// scaled component-wise by π/180.
@nogc T radians(T)(in T x) pure nothrow
if (isFloatingPoint!T || (is(T : Vector!(U, n), U, int n) && isFloatingPoint!U))
{
    // Pick the scalar type the conversion factor must be expressed in.
    static if (is(T : Vector!(U, n), U, int n))
        alias Scalar = U;
    else
        alias Scalar = T;

    enum Scalar degToRad = PI / 180;
    return x * degToRad;
}
36 
/// Linear interpolation between `a` and `b`, like GLSL's `mix`.
/// Yields `a` at `t == 0` and `b` at `t == 1`; `t` is not clamped, so values
/// outside [0, 1] extrapolate.
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
  if (is(typeof(t * b + (1 - t) * a) : S))
{
    // Weighted sum: (1 - t) parts of `a` plus t parts of `b`.
    return (1 - t) * a + t * b;
}
43 
/// Restricts `x` to the closed interval [min, max], like GLSL's `clamp`.
/// Returns `min` when `x < min`, `max` when `x > max`, and `x` otherwise
/// (including when neither comparison holds, e.g. for a NaN `x`).
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    return (x < min) ? min : ((x > max) ? max : x);
}
54 
/// Rounds `x` toward zero and returns it as a `long`.
/// Not marked `pure` because `trunc` was not `pure` (it could be otherwise).
@nogc long ltrunc(real x) nothrow
{
    immutable real towardZero = trunc(x);
    return cast(long) towardZero;
}
60 
/// Rounds `x` toward negative infinity and returns it as a `long`.
/// Not marked `pure` because `floor` was not `pure` (it could be otherwise).
@nogc long lfloor(real x) nothrow
{
    immutable real downward = floor(x);
    return cast(long) downward;
}
66 
/// Returns: Fractional part of x, computed as `x - floor(x)` like GLSL's `fract`.
/// Negative inputs wrap upward, e.g. fract(-0.25) == 0.75. The result lies in
/// [0, 1] (a tiny negative x can round up to exactly 1).
/// `T` is deduced from the argument (`fract(2.5)`); explicit instantiations
/// such as `fract!float(x)` keep working.
@nogc T fract(T)(T x) nothrow
{
    // Subtract floor(x) in floating point rather than going through a `long`:
    // the integer conversion overflows for |x| >= 2^63 (and for inf/NaN).
    return x - floor(x);
}
72 
/// Arc sine that first clamps `x` into [-1, 1].
/// Inputs pushed slightly outside the domain by rounding therefore map to
/// asin(±1) instead of producing NaN.
@nogc T safeAsin(T)(T x) pure nothrow
{
    immutable T bounded = clamp!T(x, -1, 1);
    return asin(bounded);
}
78 
/// Arc cosine that first clamps `x` into [-1, 1].
/// Inputs pushed slightly outside the domain by rounding therefore map to
/// acos(±1) instead of producing NaN.
@nogc T safeAcos(T)(T x) pure nothrow
{
    immutable T bounded = clamp!T(x, -1, 1);
    return acos(bounded);
}
84 
/// GLSL-style `step`: returns 0 when `x` is strictly below `edge`, and 1
/// otherwise (including when `x == edge`).
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
91 
/// GLSL-style `smoothstep`: 0 for t <= a, 1 for t >= b, and a cubic Hermite
/// ramp 3u^2 - 2u^3 (with u = (t - a) / (b - a)) in between.
/// See: http://en.wikipedia.org/wiki/Smoothstep
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    // Saturate outside the transition band.
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    // Normalised position inside [a, b].
    immutable T u = (t - a) / (b - a);
    return u * u * (3 - 2 * u);
}
106 
/// Returns: `true` if `i` is a power of 2 (1, 2, 4, 8, ...), `false` otherwise.
/// `i` must be non-negative (asserted); zero is not a power of 2.
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;
    // A power of two has exactly one set bit, which `i - 1` clears.
    return (i & (i - 1)) == 0;
}
113 
/// Integer base-2 logarithm of `i`, which must be a positive power of 2 (asserted).
/// Returns: the index of the single set bit, e.g. ilog2(8) == 3.
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    import core.bitop : bsr;

    // Preconditions: strictly positive and exactly one bit set.
    assert(i > 0);
    assert(isPowerOf2(i));

    immutable int bitIndex = bsr(i);
    return bitIndex;
}
122 
/// Computes the smallest power of 2 that is greater than or equal to `i`.
/// Expects 1 <= i <= 2^30; outside that range the final assertion fails
/// (when assertions are enabled).
@nogc int nextPowerOf2(int i) pure nothrow
{
    // Propagate the highest set bit of (i - 1) into every lower bit, giving 2^k - 1.
    int smeared = i - 1;
    for (int shift = 1; shift <= 16; shift *= 2)
        smeared |= smeared >> shift;

    int result = smeared + 1;
    assert(isPowerOf2(result));
    return result;
}
136 
/// Computes the smallest power of 2 that is greater than or equal to `i`.
/// Expects 1 <= i <= 2^62; outside that range the final assertion fails
/// (when assertions are enabled).
@nogc long nextPowerOf2(long i) pure nothrow
{
    // Propagate the highest set bit of (i - 1) into every lower bit, giving 2^k - 1.
    long smeared = i - 1;
    for (int shift = 1; shift <= 32; shift *= 2)
        smeared |= smeared >> shift;

    long result = smeared + 1;
    assert(isPowerOf2(result));
    return result;
}
151 
/// Computes sin(x)/x without dividing by zero or losing accuracy near x = 0:
/// when x*x is too small to change 1 + x*x, the limit value 1 is returned.
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
@nogc T sinOverX(T)(T x) pure nothrow
{
    return (1 + x * x == 1) ? 1 : sin(x) / x;
}
161 
162 
/// Signed modulo a/b where the result is guaranteed to be in [0..b[, even if
/// `a` is negative (e.g. moduloWrap(-1, 5) == 4).
/// Only positive divisors are supported (asserted).
/// Returns: the wrapped remainder.
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
{
    assert(b > 0);

    // D's % truncates toward zero, so the remainder has the sign of `a`
    // and lies in ]-b, b[.
    T x = a % b;
    if (x < 0)
    {
        // Shift a negative remainder up into [0, b[.
        x += b;
        // For floating-point T, adding b to a tiny negative remainder can round to b.
        if (x >= b)
            x = 0;
    }

    assert(x >= 0 && x < b);
    return x;
}
179 
unittest
{
    // 13 is not a power of two; the next one up is 16.
    immutable int rounded = nextPowerOf2(13);
    assert(rounded == 16);
}
184 
/**
 * Solves the linear equation a + b x = 0.
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     root = Receives -a / b when a unique root exists (being an `out`
 *            parameter, it is reset to T.init otherwise).
 * Returns: Number of roots (0 when b == 0, otherwise 1).
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    // A zero slope has either no solution or infinitely many: report none.
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    return 0;
}
201 
202 
/**
 * Finds the real roots of a quadratic polynomial a + b x + c x^2 = 0
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     outRoots = array of root results, should have room for at least 2 elements.
 * Returns: Number of roots written to outRoots. When c == 0 the equation is
 *          solved as the linear a + b x = 0 (0 or 1 root). Otherwise 0 when the
 *          roots are complex, else 2 (a double root is reported twice).
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);

    // Without the x^2 term this degenerates to a + b x = 0.
    if (c == 0)
        return solveLinear(a, b, outRoots[0]);

    T delta = b * b - 4 * a * c;
    if (delta < 0.0)
        return 0; // only complex conjugate roots

    delta = sqrt(delta);

    // The leading coefficient is c (a is the constant term), so the roots are
    // (-b ± sqrt(delta)) / (2c).
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
229 
230 
/**
 * Finds the real roots of a cubic polynomial  a + b x + c x^2 + d x^3 = 0
 * Params:
 *     a = Constant coefficient.
 *     b = Coefficient of x.
 *     c = Coefficient of x^2.
 *     d = Coefficient of x^3.
 *     outRoots = array of root results, should have room for at least 3 elements.
 * Returns: Number of real roots written to outRoots (1 or 3; repeated roots
 *          may appear more than once). When d == 0 the call is forwarded to
 *          solveQuadratic.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // Normalize to the monic form x^3 + a1 x^2 + a2 x + a3 = 0.
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R; // >= 0 means three real roots
    T shift = a1 / 3;      // undoes the depressing substitution x = t - a1/3

    if (d2 >= 0)
    {
        // 3 real roots (trigonometric method).
        if (Qcubed <= 0)
        {
            // d2 >= 0 means R^2 <= Q^3, so Q^3 <= 0 forces R^2 == 0: a triple
            // root at -a1/3. Handled separately since R / sqrt(Q^3) would be 0/0.
            outRoots[0] = outRoots[1] = outRoots[2] = -shift;
            return 3;
        }

        T P = R / sqrt(Qcubed);

        // |P| <= 1 mathematically; rounding may push it just outside acos's domain.
        if (P < -1)
            P = -1;
        else if (P > 1)
            P = 1;

        T theta = acos(P);
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - shift;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * T(PI)) / 3) - shift;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * T(PI)) / 3) - shift;
        return 3;
    }
    else
    {
        // 1 real root (Cardano). The square root is of -d2 = R^2 - Q^3 (the
        // discriminant), not of the coefficient d.
        T e = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            e = -e;
        // e > 0 in magnitude here because sqrt(-d2) > 0, so Q / e is safe.
        outRoots[0] = e + Q / e - shift;
        return 1;
    }
}
285 
/**
 * Finds the real roots of a quartic polynomial  a + b x + c x^2 + d x^3 + e x^4 = 0
 *
 * Returns: Number of real roots written to the start of `roots` (0, 2 or 4 in the
 * quartic case); the roots slice should have room for up to 4 elements. When
 * e == 0 the call is forwarded to solveCubic. Root pairs whose square-root
 * argument is negative (complex pairs) are skipped instead of being written as NaN.
 * Bugs: was documented as not passing its unit test; numerical robustness near
 * repeated roots has not been verified.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Normalize to x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0.
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Find a root for the following cubic equation:
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1 ^2 - a3^2 a0) = 0
    // aka Resolvent cubic
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numResolventRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);

    // The resolvent may have only one real root (e.g. for x^4 - 1); use the
    // largest of the roots actually returned.
    if (numResolventRoots <= 0)
        return 0;
    T y = resolventCubicRoots[0];
    foreach (i; 1 .. numResolventRoots)
        if (y < resolventCubicRoots[i])
            y = resolventCubicRoots[i];

    // Compute R, then D^2 and E^2
    T R = 0.25f * a3 * a3 - a2 + y;
    if (R < 0.0)
        return 0;
    R = sqrt(R);

    T D2, E2;
    if (R == 0)
    {
        T d1 = 0.75f * a3 * a3 - 2 * a2;
        T d2 = 2 * sqrt(y * y - 4 * a0);
        D2 = d1 + d2;
        E2 = d1 - d2;
    }
    else
    {
        T Rrec = 1 / R;
        T d1 = 0.75f * a3 * a3 - R * R - 2 * a2;
        T d2 = 0.25f * Rrec * (4 * a3 * a2 - 8 * a1 - a3 * a3 * a3);
        D2 = d1 + d2;
        E2 = d1 - d2;
    }

    T base = -0.25f * a3; // -a3/4
    T halfR = 0.5f * R;

    // Emit only real root pairs; a negative (or NaN) D^2 / E^2 means a complex pair.
    int numRoots = 0;
    if (D2 >= 0)
    {
        // z = -a3/4 + R/2 ± D/2
        T halfD = 0.5f * sqrt(D2);
        roots[numRoots++] = base + halfR + halfD;
        roots[numRoots++] = base + halfR - halfD;
    }
    if (E2 >= 0)
    {
        // z = -a3/4 - R/2 ± E/2
        T halfE = 0.5f * sqrt(E2);
        roots[numRoots++] = base - halfR + halfE;
        roots[numRoots++] = base - halfR - halfE;
    }
    return numRoots;
}
354 
355 
unittest
{
    // True when some entry of `candidates` is within 1e-7 of `expected`.
    bool hasRootNear(double[] candidates, double expected)
    {
        foreach (candidate; candidates)
        {
            if (abs(candidate - expected) < 1e-7)
                return true;
        }
        return false;
    }

    // cubic: 0.25 x^3 + 0.75 x^2 - 1.5 x - 2 = 0.25 (x + 4)(x + 1)(x - 2)
    {
        double[3] found;
        immutable int count = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, found[]);
        assert(count == 3);
        assert(hasRootNear(found[], -4));
        assert(hasRootNear(found[], -1));
        assert(hasRootNear(found[], 2));
    }

    // quartic: x^4 + 2 x^3 - x^2 - 2 x = x (x + 2)(x + 1)(x - 1)
    {
        double[4] found;
        immutable int count = solveQuartic!double(0, -2, -1, 2, 1, found[]);
        assert(count == 4);
        assert(hasRootNear(found[], -2));
        assert(hasRootNear(found[], -1));
        assert(hasRootNear(found[], 0));
        assert(hasRootNear(found[], 1));
    }
}
388 
/// Arithmetic mean of the elements of `r`.
/// Returns: the mean as a `double`, or `double.nan` for an empty range.
/// Integral elements are accumulated in a `real`, so the final division is not
/// an integer division (the mean of [1, 2] is 1.5, not 1) and the running sum
/// does not overflow the element type. Qualifiers (const/immutable) on the
/// element type are stripped so the accumulator can be updated.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    alias E = Unqual!(typeof(r.front()));
    // An integral accumulator would truncate in `sum / count` and may overflow.
    static if (isIntegral!E)
        real sum = 0;
    else
        E sum = 0;

    long count = 0;
    foreach (e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
404 
/// Smallest element of a range, converted to `double`.
/// Returns: `double.infinity` for an empty range, mirroring JavaScript's `Math.min()`.
double minElement(R)(R r) if (isInputRange!R)
{
    return r.empty ? double.infinity : minmax!("<", R)(r);
}
414 
/// Largest element of a range, converted to `double`.
/// Returns: `-double.infinity` for an empty range, mirroring JavaScript's `Math.max()`.
double maxElement(R)(R r) if (isInputRange!R)
{
    return r.empty ? -double.infinity : minmax!(">", R)(r);
}
424 
/// Sample variance of a range: the sum of squared deviations from the mean
/// divided by n - 1.
/// Returns: `double.nan` for an empty range, 0 for a single element.
/// The range is walked twice: once through `save` to get the mean, then for the deviations.
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    immutable mean = average(r.save);

    double sumOfSquares = 0;
    long n = 0;
    foreach (x; r)
    {
        immutable dev = x - mean;
        sumOfSquares += dev * dev;
        ++n;
    }

    // Fewer than two samples carry no spread information.
    return (n > 1) ? sumOfSquares / (n - 1.0) : 0.0;
}
445 
/// Sample standard deviation of a range: the square root of `variance(r)`.
/// Returns: `double.nan` for an empty range, 0 for a single element.
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    immutable double v = variance(r);
    return sqrt(v);
}
451 
private
{
    /// Returns the element of the non-empty range `r` that wins every
    /// `e op best` comparison: "<" selects the minimum, ">" the maximum.
    /// On ties the earliest element is kept.
    auto minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);
        // The return type is inferred (the old `typeof(R.front())` applied
        // `front` to the type R, which does not work for arrays), and
        // qualifiers are stripped so `best` can be reassigned for const elements.
        Unqual!(typeof(r.front())) best = r.front();
        r.popFront();
        foreach (e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}
466 
/// Reciprocal square root, 1 / sqrt(x).
/// For `float` this uses the SSE `rsqrtss` approximation (via the intel-intrinsics
/// `_mm_rsqrt_ss`), so the result is only approximate; other floating-point
/// types compute `1 / sqrt(x)` directly.
@nogc T inverseSqrt(T)(T x) pure nothrow if (isFloatingPoint!T)
{
    static if (!is(T == float))
    {
        return 1 / sqrt(x);
    }
    else
    {
        return _mm_cvtss_f32(_mm_rsqrt_ss(_mm_set_ss(x)));
    }
}
481 
unittest
{
    // 1 / sqrt(1) == 1: the float path (SSE approximation) and the double path
    // should both land within 1e-3 of it.
    immutable float approx = inverseSqrt!float(1);
    immutable double direct = inverseSqrt!double(1);
    assert(abs(approx - 1) < 1e-3);
    assert(abs(direct - 1) < 1e-3);
}