1 /**
2   Useful math functions and range-based statistic computations.
3 
4   If you need real statistics, consider using the $(WEB github.com/dsimcha/dstats,Dstats) library.
5  */
6 module gfm.math.funcs;
7 
8 import std.math,
9        std.traits,
10        std.range,
11        std.math;
12 
// Collapse the two compiler-provided inline-assembler identifiers (32-bit and
// 64-bit x86) into one local version identifier; inverseSqrt tests AsmX86.
version (D_InlineAsm_X86)
    version = AsmX86;
else version (D_InlineAsm_X86_64)
    version = AsmX86;
21 
/// Converts an angle from radians to degrees.
/// Params: x = angle in radians; any non-integral type (float, double, real, ...).
/// Returns: x scaled by 180 / PI.
@nogc T degrees(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum radToDeg = 180 / PI;
    return x * radToDeg;
}
27 
/// Converts an angle from degrees to radians.
/// Params: x = angle in degrees; any non-integral type (float, double, real, ...).
/// Returns: x scaled by PI / 180.
@nogc T radians(T)(T x) pure nothrow if (!isIntegral!T)
{
    enum degToRad = PI / 180;
    return x * degToRad;
}
33 
/// Linear interpolation between a and b, akin to GLSL's mix:
/// blends from a (t = 0) towards b (t = 1) as t * b + (1 - t) * a.
@nogc S lerp(S, T)(S a, S b, T t) pure nothrow
  if (is(typeof(t * b + (1 - t) * a) : S))
{
    auto blended = t * b + (1 - t) * a;
    return blended;
}
40 
/// Clamps x to the closed interval [min, max], akin to GLSL's clamp.
/// Intended for min <= max. x is returned unchanged when it compares neither
/// below min nor above max (this includes NaN).
@nogc T clamp(T)(T x, T min, T max) pure nothrow
{
    return (x < min) ? min : ((x > max) ? max : x);
}
51 
/// Truncates x toward zero and returns the integral result as a long.
@nogc long ltrunc(real x) nothrow // could be pure, but trunc isn't pure
{
    immutable real whole = trunc(x);
    return cast(long) whole;
}
57 
/// Rounds x toward negative infinity and returns the integral result as a long.
@nogc long lfloor(real x) nothrow // could be pure, but floor isn't pure
{
    immutable real below = floor(x);
    return cast(long) below;
}
63 
/// Fractional part of x, computed as x - floor(x) (GLSL fract convention),
/// so fract(2.25) == 0.25 and fract(-1.25) == 0.75.
///
/// T is the result type. It defaults to real, so `fract(x)` can be called
/// without a template argument; `fract!float(x)` etc. keep working.
/// Returns: Fractional part of x.
@nogc T fract(T = real)(real x) nothrow
{
    // Subtract floor(x) directly rather than going through a long: the cast
    // overflows for |x| >= 2^63 and is meaningless for NaN or infinity.
    return x - floor(x);
}
69 
/// Arcsine that tolerates inputs slightly outside its domain: x is first
/// clamped to [-1, 1], so values like 1.0000001 give asin(1) rather than NaN.
@nogc T safeAsin(T)(T x) pure nothrow
{
    T bounded = clamp!T(x, -1, 1);
    return asin(bounded);
}
75 
/// Arccosine that tolerates inputs slightly outside its domain: x is first
/// clamped to [-1, 1], so values like -1.0000001 give acos(-1) rather than NaN.
@nogc T safeAcos(T)(T x) pure nothrow
{
    T bounded = clamp!T(x, -1, 1);
    return acos(bounded);
}
81 
/// GLSL-style step function.
/// Returns: 0 when x < edge, and 1 otherwise (including when either value is NaN).
@nogc T step(T)(T edge, T x) pure nothrow
{
    if (x < edge)
        return 0;
    return 1;
}
88 
/// GLSL-style smoothstep: 0 when t <= a, 1 when t >= b, and in between the
/// cubic Hermite curve u^2 (3 - 2u) with u = (t - a) / (b - a).
/// See: http://en.wikipedia.org/wiki/Smoothstep
@nogc T smoothStep(T)(T a, T b, T t) pure nothrow
{
    if (t <= a)
        return 0;
    if (t >= b)
        return 1;

    T u = (t - a) / (b - a);
    return u * u * (3 - 2 * u);
}
103 
/// Returns: true if i is a power of 2 (1, 2, 4, ...); false for 0.
/// i must not be negative (checked by an assertion).
@nogc bool isPowerOf2(T)(T i) pure nothrow if (isIntegral!T)
{
    assert(i >= 0);
    if (i == 0)
        return false;
    // A power of two has a single set bit, so clearing its lowest set bit yields 0.
    return (i & (i - 1)) == 0;
}
110 
/// Integer base-2 logarithm of a power of two.
/// Params: i = a strictly positive power of 2 (checked by assertions).
/// Returns: the exponent n such that 2^^n == i.
/// TODO: use bit-scan intrinsics instead of a loop
@nogc int ilog2(T)(T i) nothrow if (isIntegral!T)
{
    assert(i > 0);
    assert(isPowerOf2(i));
    int exponent;
    for (; i > 1; i /= 2)
        ++exponent;
    return exponent;
}
125 
/// Computes the smallest power of 2 that is greater than or equal to i.
/// Params: i = input value; must be <= 2^30, otherwise the result does not fit in an int.
/// Returns: the next power of 2, or 1 when i <= 1 (including zero and negatives).
@nogc int nextPowerOf2(int i) pure nothrow
{
    assert(i <= (1 << 30), "nextPowerOf2: result would overflow an int");

    // Without this guard, i - 1 == -1 for i == 0 smears to all ones and wraps to 0.
    if (i <= 1)
        return 1;

    int v = i - 1;
    // Copy the highest set bit into every lower position, then add one.
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v++;
    assert((v & (v - 1)) == 0);
    return v;
}
139 
/// Computes the smallest power of 2 that is greater than or equal to i.
/// Params: i = input value; must be <= 2^62, otherwise the result does not fit in a long.
/// Returns: the next power of 2, or 1 when i <= 1 (including zero and negatives).
@nogc long nextPowerOf2(long i) pure nothrow
{
    assert(i <= (1L << 62), "nextPowerOf2: result would overflow a long");

    // Without this guard, i - 1 == -1 for i == 0 smears to all ones and wraps to 0.
    if (i <= 1)
        return 1;

    long v = i - 1;
    // Copy the highest set bit into every lower position, then add one.
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v |= v >> 32;
    v++;
    assert((v & (v - 1)) == 0);
    return v;
}
154 
/// Computes sin(x)/x accurately. When x is so small that 1 + x*x rounds to 1
/// (which includes x == 0), the limit value 1 is returned instead of dividing.
/// See_also: $(WEB www.plunk.org/~hatch/rightway.php)
@nogc T sinOverX(T)(T x) pure nothrow
{
    if (1 + x * x != 1)
        return sin(x) / x;
    return 1;
}
164 
165 
/// Signed modulo whose result always lies in [0..b[, even when a is negative,
/// e.g. moduloWrap(-1, 5) == 4 and moduloWrap(-10, 5) == 0.
/// Only positive divisors are supported (checked by the in-contract).
@nogc T moduloWrap(T)(T a, T b) pure nothrow if (isSigned!T)
in
{
    assert(b > 0);
}
body
{
    // D's % truncates toward zero, so the remainder has the sign of a and lies in ]-b, b[.
    T x = cast(T)(a % b);

    // Shift a negative remainder up by one period into [0, b[.
    if (x < 0)
        x = cast(T)(x + b);

    assert(x >= 0 && x < b);
    return x;
}
186 
unittest
{
    // 13 lies between 2^3 and 2^4, so it must round up to 16.
    immutable int rounded = nextPowerOf2(13);
    assert(rounded == 16);
}
191 
/**
 * Solves the linear equation a + b x = 0 for x.
 * Params:
 *     a = constant term.
 *     b = coefficient of x.
 *     root = receives -a / b when b != 0; otherwise it keeps the T.init value
 *            that `out` assigns on entry.
 * Returns: Number of roots written (1, or 0 when b == 0).
 */
@nogc int solveLinear(T)(T a, T b, out T root) pure nothrow if (isFloatingPoint!T)
{
    if (b != 0)
    {
        root = -a / b;
        return 1;
    }
    return 0;
}
208 
209 
/**
 * Finds the real roots of a quadratic polynomial a + b x + c x^2 = 0.
 *
 * Falls back to solveLinear when c == 0. A repeated root (zero discriminant)
 * is reported twice.
 * Params:
 *     a = constant coefficient.
 *     b = coefficient of x.
 *     c = coefficient of x^2.
 *     outRoots = array of root results, should have room for at least 2 elements.
 * Returns: Number of roots written to outRoots (0, 1 or 2).
 */
@nogc int solveQuadratic(T)(T a, T b, T c, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 2);
    if (c == 0)
        return solveLinear(a, b, outRoots[0]);

    T delta = b * b - 4 * a * c;
    if (delta < 0)
        return 0; // only complex roots

    delta = sqrt(delta);

    // The leading coefficient is c (a is the constant term), so the roots are
    // (-b -/+ sqrt(delta)) / (2 c).
    T oneOver2c = 0.5 / c;

    outRoots[0] = oneOver2c * (-b - delta);
    outRoots[1] = oneOver2c * (-b + delta);
    return 2;
}
233 
234 
/**
 * Finds the real roots of a cubic polynomial  a + b x + c x^2 + d x^3 = 0
 *
 * Falls back to solveQuadratic when d == 0. Otherwise the trigonometric
 * method is used when all three roots are real (repeated roots are reported
 * multiple times), and Cardano's formula when only one root is real.
 * Params:
 *     outRoots = array of root results, should have room for at least 3 elements.
 * Returns: Number of roots in outRoots.
 * See_also: $(WEB www.codeguru.com/forum/archive/index.php/t-265551.html)
 */
@nogc int solveCubic(T)(T a, T b, T c, T d, T[] outRoots) pure nothrow if (isFloatingPoint!T)
{
    assert(outRoots.length >= 3);
    if (d == 0)
        return solveQuadratic(a, b, c, outRoots);

    // Normalize to the monic form x^3 + a1 x^2 + a2 x + a3 = 0.
    T a1 = c / d,
      a2 = b / d,
      a3 = a / d;

    T Q = (a1 * a1 - 3 * a2) / 9,
      R = (2 * a1 * a1 * a1 - 9 * a1 * a2 + 27 * a3) / 54;

    T Qcubed = Q * Q * Q;
    T d2 = Qcubed - R * R; // discriminant: >= 0 means three real roots

    if (d2 >= 0)
    {
        // 3 real roots (some possibly repeated)
        if (Q <= 0)
        {
            // Q^3 >= R^2 >= 0 here, so Q <= 0 means Q == R == 0: a triple root.
            // The trigonometric formula below would evaluate 0 / 0 in that case.
            outRoots[0] = outRoots[1] = outRoots[2] = -a1 / 3;
            return 3;
        }

        T P = R / sqrt(Qcubed);
        // Rounding can push P marginally outside acos's domain [-1, 1].
        if (P < -1)
            P = -1;
        else if (P > 1)
            P = 1;

        T theta = acos(P);
        T sqrtQ = sqrt(Q);

        outRoots[0] = -2 * sqrtQ * cos(theta / 3) - a1 / 3;
        outRoots[1] = -2 * sqrtQ * cos((theta + 2 * PI) / 3) - a1 / 3;
        outRoots[2] = -2 * sqrtQ * cos((theta + 4 * PI) / 3) - a1 / 3;
        return 3;
    }
    else
    {
        // 1 real root; the other two are complex conjugates.
        // sqrt(-d2) == sqrt(R^2 - Q^3), which is positive in this branch, so e != 0.
        T e = (sqrt(-d2) + abs(R)) ^^ cast(T)(1.0 / 3.0);
        if (R > 0)
            e = -e;
        outRoots[0] = e + Q / e - a1 / 3.0;
        return 1;
    }
}
285 
/**
 * Finds the real roots of a quartic polynomial  a + b x + c x^2 + d x^3 + e x^4 = 0
 * using Ferrari's method (via the resolvent cubic).
 *
 * Falls back to solveCubic when e == 0.
 * Params:
 *     roots = array of root results, should have room for at least 4 elements.
 * Returns: Number of real roots written to roots (0, 2 or 4 when e != 0).
 *          A complex-conjugate pair of roots is skipped rather than written as NaN.
 * NOTE(review): an earlier revision flagged this function as failing its unit
 * test; re-run the module unittest after changes.
 * See_also: $(WEB mathworld.wolfram.com/QuarticEquation.html)
 */
@nogc int solveQuartic(T)(T a, T b, T c, T d, T e, T[] roots) pure nothrow if (isFloatingPoint!T)
{
    assert(roots.length >= 4);

    if (e == 0)
        return solveCubic(a, b, c, d, roots);

    // Normalize to the monic form x^4 + a3 x^3 + a2 x^2 + a1 x + a0 = 0.
    T a0 = a / e,
      a1 = b / e,
      a2 = c / e,
      a3 = d / e;

    // Find a root for the following cubic equation:
    //     y^3 - a2 y^2 + (a1 a3 - 4 a0) y + (4 a2 a0 - a1 ^2 - a3^2 a0) = 0
    // aka Resolvent cubic
    T b0 = 4 * a2 * a0 - a1 * a1 - a3 * a3 * a0;
    T b1 = a1 * a3 - 4 * a0;
    T b2 = -a2;
    T[3] resolventCubicRoots;
    int numResolventRoots = solveCubic!T(b0, b1, b2, 1, resolventCubicRoots[]);

    // Ferrari's method needs only one real root of the resolvent; a real cubic
    // may have 1 or 3 of them, so do not insist on 3.
    if (numResolventRoots < 1)
        return 0;

    // Use the largest real root of the resolvent.
    T y = resolventCubicRoots[0];
    foreach (i; 1 .. numResolventRoots)
        if (resolventCubicRoots[i] > y)
            y = resolventCubicRoots[i];

    // Compute R, then D^2 = d1 + d2 and E^2 = d1 - d2.
    T R = 0.25f * a3 * a3 - a2 + y;
    if (!(R >= 0)) // negated so that NaN also bails out
        return 0;
    R = sqrt(R);

    T d1 = void,
      d2 = void;
    if (R == 0)
    {
        T disc = y * y - 4 * a0;
        if (disc < 0)
            return 0; // D^2 and E^2 are non-real, so no root is real
        d1 = 0.75f * a3 * a3 - 2 * a2;
        d2 = 2 * sqrt(disc);
    }
    else
    {
        T Rsquare = R * R;
        T Rrec = 1 / R;
        d1 = 0.75f * a3 * a3 - Rsquare - 2 * a2;
        d2 = 0.25f * Rrec * (4 * a3 * a2 - 8 * a1 - a3 * a3 * a3);
    }

    // Roots are -a3/4 + R/2 +/- D/2 and -a3/4 - R/2 +/- E/2; a negative D^2
    // or E^2 means that pair is complex and is skipped.
    immutable T shift = -0.25f * a3;
    immutable T halfR = 0.5f * R;
    int count = 0;

    T Dsquare = d1 + d2;
    if (Dsquare >= 0)
    {
        T halfD = 0.5f * sqrt(Dsquare);
        roots[count++] = shift + halfR + halfD;
        roots[count++] = shift + halfR - halfD;
    }

    T Esquare = d1 - d2;
    if (Esquare >= 0)
    {
        T halfE = 0.5f * sqrt(Esquare);
        roots[count++] = shift - halfR + halfE;
        roots[count++] = shift - halfR - halfE;
    }
    return count;
}
354 
355 
unittest
{
    // True when some entry of `candidates` is within 1e-7 of `expected`.
    bool hasRoot(double[] candidates, double expected)
    {
        foreach (candidate; candidates)
        {
            if (abs(candidate - expected) < 1e-7)
                return true;
        }
        return false;
    }

    // test cubic: 1/4 x^3 + 3/4 x^2 - 3/2 x - 2 = 0, i.e. (x + 4)(x + 1)(x - 2) / 4
    {
        double[3] cubicRoots;
        immutable int count = solveCubic!double(-2, -3 / 2.0, 3 / 4.0, 1 / 4.0, cubicRoots[]);
        assert(count == 3);
        assert(hasRoot(cubicRoots[], -4));
        assert(hasRoot(cubicRoots[], -1));
        assert(hasRoot(cubicRoots[], 2));
    }

    // test quartic: x^4 + 2 x^3 - x^2 - 2 x = 0, i.e. x (x + 2)(x + 1)(x - 1)
    {
        double[4] quarticRoots;
        immutable int count = solveQuartic!double(0, -2, -1, 2, 1, quarticRoots[]);
        assert(count == 4);
        assert(hasRoot(quarticRoots[], -2));
        assert(hasRoot(quarticRoots[], -1));
        assert(hasRoot(quarticRoots[], 0));
        assert(hasRoot(quarticRoots[], 1));
    }
}
388 
/// Arithmetic mean of an input range.
/// Elements are accumulated in a `real`, so integral ranges are averaged with
/// floating-point division (average([1, 2]) == 1.5) and their sum cannot
/// overflow the element type.
/// Returns: the mean, or NaN for an empty range.
double average(R)(R r) if (isInputRange!R)
{
    if (r.empty)
        return double.nan;

    real sum = 0;
    long count = 0;
    foreach (e; r)
    {
        sum += e;
        ++count;
    }
    return sum / count;
}
404 
/// Minimum of a range.
/// Returns: the smallest element as a double; +infinity for an empty range,
/// mirroring JavaScript's Math.min() with no arguments.
double minElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!("<", R)(r);
    return double.infinity;
}
414 
/// Maximum of a range.
/// Returns: the largest element as a double; -infinity for an empty range,
/// mirroring JavaScript's Math.max() with no arguments.
double maxElement(R)(R r) if (isInputRange!R)
{
    if (!r.empty)
        return minmax!(">", R)(r);
    return -double.infinity;
}
424 
/// Sample variance of a forward range: the sum of squared deviations from the
/// mean divided by (n - 1).
/// Returns: NaN for an empty range, 0 for a single element.
double variance(R)(R r) if (isForwardRange!R)
{
    if (r.empty)
        return double.nan;

    // Average over a saved copy so that r itself can be iterated again below.
    double mean = average(r.save);

    double sumSq = 0;
    long n = 0;
    foreach (e; r)
    {
        sumSq += (e - mean) ^^ 2;
        ++n;
    }
    // Sample (Bessel-corrected) estimator.
    return (n <= 1) ? 0.0 : sumSq / (n - 1.0);
}
445 
/// Standard deviation of a forward range: the square root of its sample variance.
double standardDeviation(R)(R r) if (isForwardRange!R)
{
    immutable double v = variance(r);
    return sqrt(v);
}
451 
private
{
    /// Reduces a non-empty input range to one element using the comparison
    /// operator `op`: "<" keeps the minimum, ">" the maximum. The earliest
    /// element wins ties because the comparison is strict.
    /// The return type is ElementType!R. `typeof(R.front())` would apply UFCS
    /// `front` to a type, which fails for slice ranges such as int[].
    ElementType!R minmax(string op, R)(R r) if (isInputRange!R)
    {
        assert(!r.empty);
        auto best = r.front();
        r.popFront();
        foreach (e; r)
        {
            mixin("if (e " ~ op ~ " best) best = e;");
        }
        return best;
    }
}
466 
/// Reciprocal square root, 1 / sqrt(x).
/// When the AsmX86 version is set and T is float, the SSE `rsqrtss` instruction
/// is used, which is a fast hardware approximation. Every other case computes
/// 1 / sqrt(x) with std.math.
@nogc T inverseSqrt(T)(T x) pure nothrow if (isFloatingPoint!T)
{
    version (AsmX86)
        enum bool useRsqrtss = is(T == float);
    else
        enum bool useRsqrtss = false;

    static if (useRsqrtss)
    {
        float result;

        asm pure nothrow @nogc
        {
            movss XMM0, x;
            rsqrtss XMM0, XMM0;
            movss result, XMM0;
        }
        return result;
    }
    else
        return 1 / sqrt(x);
}
490 
unittest
{
    // 1 / sqrt(1) == 1; the loose 1e-3 tolerance accommodates the rsqrtss approximation.
    immutable float viaFloat = inverseSqrt!float(1);
    immutable double viaDouble = inverseSqrt!double(1);
    assert(abs(viaFloat - 1) < 1e-3);
    assert(abs(viaDouble - 1) < 1e-3);
}