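/**
  Provide a 2^N-bit integer type.
  Guaranteed to never allocate, with the binary layout expected for the platform's endianness.
  Recursive implementation with very slow division.

  <b>Supports all operations that builtin integers support.</b>

  Bugs: in mixed signed/unsigned operations (arithmetic, division, comparison) the right operand
  is converted to the left operand's type, so the left operand's signedness wins. This differs
  from builtin integers, where the unsigned type takes precedence.
  Compile-time literals are available through `literal`, e.g. `int128.literal!"0x1_0000"`.
 */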
11 module gfm.math.wideint;
12 
13 import std.traits,
14        std.ascii;
15 import std.format : FormatSpec;
16 
/// Signed wide integer type of the requested width.
/// Params:
///    bits = width in bits; must be a power of 2.
template wideint(int bits)
{
    alias wideint = integer!(true, bits);
}
24 
/// Unsigned wide integer type of the requested width.
/// Params:
///    bits = width in bits; must be a power of 2.
template uwideint(int bits)
{
    alias uwideint = integer!(false, bits);
}
32 
// Ready-made wide integer types. Any power of 2 of at least 128 bits yields a
// wideIntImpl; smaller powers of 2 (8..64) map to the builtin integer types.

alias int128  = wideint!128;  // stand-in for D's reserved `cent`
alias uint128 = uwideint!128; // stand-in for D's reserved `ucent`

alias int256  = wideint!256;
alias uint256 = uwideint!256;
40 
/// Maps a signedness and a power-of-2 bit width to an integer type:
/// widths 8, 16, 32 and 64 select the matching builtin type, every other
/// power of 2 selects the recursive `wideIntImpl` struct.
private template integer(bool signed, int bits)
    if ((bits & (bits - 1)) == 0)
{
    static if (bits != 8 && bits != 16 && bits != 32 && bits != 64)
    {
        // No builtin type of this width: use the recursive implementation.
        alias integer = wideIntImpl!(signed, bits);
    }
    else static if (signed)
    {
        static if (bits == 8)       alias integer = byte;
        else static if (bits == 16) alias integer = short;
        else static if (bits == 32) alias integer = int;
        else                        alias integer = long;
    }
    else
    {
        static if (bits == 8)       alias integer = ubyte;
        else static if (bits == 16) alias integer = ushort;
        else static if (bits == 32) alias integer = uint;
        else                        alias integer = ulong;
    }
}
80 
/// Catch-all overload: any bit width that is not a power of 2 is rejected
/// at compile time with a readable message.
private template integer(bool signed, int bits)
    if ((bits & (bits - 1)) != 0)
{
    static assert(false, "wide integer bits must be a power of 2");
}
86 
/// Recursive 2^n-bit integer implementation.
///
/// A value is stored as two halves of bits/2 bits each: `hi`, which has the
/// same signedness as the whole integer, and `lo`, which is always unsigned.
/// Params:
///    signed = whether the integer is two's-complement signed.
///    bits = total width in bits; must be a power of 2 and at least 128.
struct wideIntImpl(bool signed, int bits)
{
    static assert(bits >= 128);
    private
    {
        alias wideIntImpl self;

        template isSelf(T)
        {
            enum bool isSelf = is(Unqual!T == self);
        }

        alias integer!(true, bits/2) sub_int_t;   // signed bits/2 integer
        alias integer!(false, bits/2) sub_uint_t; // unsigned bits/2 integer

        alias integer!(true, bits/4) sub_sub_int_t;   // signed bits/4 integer
        alias integer!(false, bits/4) sub_sub_uint_t; // unsigned bits/4 integer

        static if(signed)
            alias sub_int_t hi_t; // hi_t has same signedness as the whole struct
        else
            alias sub_uint_t hi_t;

        alias sub_uint_t low_t;   // low_t is always unsigned

        // Markers inspected by template constraints (e.g. opAssign, opCast)
        // to recognise wide integers and their width/signedness.
        enum _isWideIntImpl = true,
             _bits = bits,
             _signed = signed;
    }

    /// Construct from a value. Accepts whatever `opAssign` accepts:
    /// builtin integers and wide integers of the same or a smaller width.
    @nogc this(T)(T x) pure nothrow
    {
        opAssign!T(x);
    }

    // Private functions used by the `literal` template.

    /// Returns true when `digits` is a non-empty decimal string, or "0x"
    /// followed by hex digits, where '_' separators are allowed anywhere.
    /// At least one actual digit is required, so "" and "0x" are rejected.
    private static bool isValidDigitString(string digits)
    {
        import std.algorithm : startsWith;
        import std.ascii : isDigit;

        bool sawDigit = false;
        if (digits.startsWith("0x"))
        {
            foreach (d; digits[2 .. $])
            {
                if (isHexDigit(d))
                    sawDigit = true;
                else if (d != '_')
                    return false;
            }
        }
        else // decimal
        {
            foreach (d; digits)
            {
                if (isDigit(d))
                    sawDigit = true;
                else if (d != '_')
                    return false;
            }
        }
        return sawDigit;
    }

    /// Parses a digit string already accepted by `isValidDigitString`.
    /// Overflow is not detected: excess high digits wrap around.
    private static typeof(this) literalImpl(string digits)
    {
        import std.algorithm : startsWith;
        import std.ascii : isDigit;

        typeof(this) value = 0;
        if (digits.startsWith("0x"))
        {
            foreach (d; digits[2 .. $])
            {
                if (d == '_')
                    continue;
                value <<= 4;
                if (isDigit(d))
                    value += d - '0';
                else
                    value += 10 + toUpper(d) - 'A';
            }
        }
        else
        {
            foreach (d; digits)
            {
                if (d == '_')
                    continue;
                value *= 10;
                value += d - '0';
            }
        }
        return value;
    }

    /// Construct from compile-time digit string.
    ///
    /// Both decimal and hex ("0x"-prefixed) digit strings are supported;
    /// '_' may be used as a separator.
    ///
    /// Example:
    /// ----
    /// auto x = int128.literal!"20_000_000_000_000_000_001";
    /// assert((x >>> 1) == 0x8AC7_2304_89E8_0000);
    ///
    /// auto y = int128.literal!"0x1_158E_4609_13D0_0001";
    /// assert(y == x);
    /// ----
    template literal(string digits)
    {
        static assert(isValidDigitString(digits),
                      "invalid digits in literal: " ~ digits);
        enum literal = literalImpl(digits);
    }

    /// Assign with a smaller unsigned type (zero-extended).
    @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isUnsigned!T)
    {
        hi = 0;
        lo = n;
        return this;
    }

    /// Assign with a smaller signed type (sign is extended).
    @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isSigned!T)
    {
        // shorter int always gets sign-extended,
        // regardless of the larger int being signed or not
        hi = (n < 0) ? cast(hi_t)(-1) : cast(hi_t)0;

        // will also sign extend as well if needed
        lo = cast(sub_int_t)n;
        return this;
    }

    /// Assign with a wide integer of the same size (bits are copied, sign is lost).
    @nogc ref self opAssign(T)(T n) pure nothrow if (is(typeof(T._isWideIntImpl)) && T._bits == bits)
    {
        hi = n.hi;
        lo = n.lo;
        return this;
    }

    /// Assign with a smaller wide integer (sign is extended accordingly).
    @nogc ref self opAssign(T)(T n) pure nothrow if (is(typeof(T._isWideIntImpl)) && T._bits < bits)
    {
        static if (T._signed)
        {
            // shorter int always gets sign-extended,
            // regardless of the larger int being signed or not
            hi = cast(hi_t)((n < 0) ? -1 : 0);

            // will also sign extend as well if needed
            lo = cast(sub_int_t)n;
            return this;
        }
        else
        {
            hi = 0;
            lo = n;
            return this;
        }
    }

    /// Cast to a smaller integer type (truncation: only the low half is used).
    @nogc T opCast(T)() pure const nothrow if (isIntegral!T)
    {
        return cast(T)lo;
    }

    /// Cast to bool: true when any bit is set.
    @nogc T opCast(T)() pure const nothrow if (is(T == bool))
    {
        return this != 0;
    }

    /// Cast to wide integer of any size: narrower targets truncate,
    /// wider or equal targets go through `opAssign`.
    @nogc T opCast(T)() pure const nothrow if (is(typeof(T._isWideIntImpl)))
    {
        static if (T._bits < bits)
            return cast(T)lo;
        else
            return T(this);
    }

    /// Converts to a string.
    ///
    /// `%x` prints "0x" followed by all bits/4 hex digits (two's complement
    /// for negative values). Every other spec, including `%d` and `%s`,
    /// prints the value in decimal, with a leading '-' for negative values.
    void toString(DG, Char)(DG sink, FormatSpec!Char fmt) const
        if (is(typeof(sink((const(Char)[]).init))))
    {
        if (fmt.spec == 'x')
        {
            // Each half holds bits/2 bits, i.e. bits/8 hex digits.
            enum hexdigits = bits / 8;
            Char[1] buf;

            sink("0x");
            foreach (i; 0 .. hexdigits)
            {
                // Most significant nibble first; derived from the width
                // rather than hard-coded so that bits > 128 work too.
                immutable shift = (hexdigits - 1 - i) * 4;
                buf[0] = hexDigits[cast(int)((hi >> shift) & 15)];
                sink(buf[]);
            }
            foreach (i; 0 .. hexdigits)
            {
                immutable shift = (hexdigits - 1 - i) * 4;
                buf[0] = hexDigits[cast(int)((lo >> shift) & 15)];
                sink(buf[]);
            }
        }
        else // default to decimal
        {
            // The maximum number of decimal digits is basically
            // ceil(log_10(2^^bits - 1)), which is slightly below
            // ceil(bits * log(2)/log(10)). The value 0.30103 is a slight
            // overestimate of log(2)/log(10), to be sure we never
            // underestimate. We add 1 to account for rounding up.
            enum maxDigits = cast(size_t)(0.30103 * bits) + 1;
            Char[maxDigits] buf;
            size_t i = maxDigits; // digits are written right-to-left into buf[i .. $]

            // Work on the unsigned magnitude: negating the most negative
            // signed value overflows back to itself, but as an unsigned
            // number it is the correct magnitude 2^^(bits-1).
            integer!(false, bits) mag = this;
            if (isNegative())
            {
                sink("-");
                mag = -mag;
            }

            // do/while so that zero still produces a single '0' digit.
            do
            {
                assert(i > 0);
                buf[--i] = cast(Char)('0' + cast(int)(mag % 10));
                mag /= 10;
            }
            while (mag != 0);

            sink(buf[i .. $]);
        }
    }

    /// Binary operator with a non-wide or differently-typed right operand:
    /// the operand is first converted to this type (so the left operand's
    /// signedness wins).
    @nogc self opBinary(string op, T)(T o) pure const nothrow if (!isSelf!T)
    {
        self r = this;
        self y = o;
        return r.opOpAssign!(op)(y);
    }

    /// Binary operator between two values of this exact type.
    @nogc self opBinary(string op, T)(T y) pure const nothrow if (isSelf!T)
    {
        self r = this; // copy
        self o = y;
        return r.opOpAssign!(op)(o);
    }

    /// Compound assignment with an operand of another type; the operand is
    /// converted to this type first.
    @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (!isSelf!T)
    {
        const(self) o = y;
        return opOpAssign!(op)(o);
    }

    /// Compound assignment implementing all supported operators:
    /// + - * / % & | ^ << >> >>>.
    /// `>>` is arithmetic for signed types and logical for unsigned ones;
    /// `>>>` is always logical. Division by zero trips an assert.
    @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (isSelf!T)
    {
        static if (op == "+")
        {
            hi += y.hi;
            if (lo + y.lo < lo) // deal with overflow
                ++hi;
            lo += y.lo;
        }
        else static if (op == "-")
        {
            opOpAssign!"+"(-y);
        }
        else static if (op == "<<")
        {
            if (y >= bits)
            {
                hi = 0;
                lo = 0;
            }
            else if (y >= bits / 2)
            {
                hi = lo << (y.lo - bits / 2);
                lo = 0;
            }
            else if (y > 0)
            {
                hi = (lo >>> (-y.lo + bits / 2)) | (hi << y.lo);
                lo = lo << y.lo;
            }
        }
        else static if (op == ">>" || op == ">>>")
        {
            assert(y >= 0);
            static if (!signed || op == ">>>")
                immutable(sub_int_t) signFill = 0;
            else
                immutable(sub_int_t) signFill = cast(sub_int_t)(isNegative() ? -1 : 0);

            if (y >= bits)
            {
                hi = signFill;
                lo = signFill;
            }
            else if (y >= bits/2)
            {
                // The high half is signed for signed types, so `>>` on it is
                // arithmetic; `>>>` must use a logical shift so the sign bit
                // is not smeared into the result.
                static if (op == ">>>")
                    lo = hi >>> (y.lo - bits/2);
                else
                    lo = hi >> (y.lo - bits/2);
                hi = signFill;
            }
            else if (y > 0)
            {
                lo = (hi << (-y.lo + bits/2)) | (lo >> y.lo);
                static if (op == ">>>")
                    hi = hi >>> y.lo;
                else
                    hi = hi >> y.lo;
            }
        }
        else static if (op == "*")
        {
            // Schoolbook multiplication on four bits/4-bit limbs; partial
            // products whose shift would reach `bits` are dropped (i + j < 4)
            // because they only affect bits above the result width.
            sub_sub_uint_t[4] a = toParts();
            sub_sub_uint_t[4] b = y.toParts();

            this = 0;
            for(int i = 0; i < 4; ++i)
                for(int j = 0; j < 4 - i; ++j)
                    this += self(cast(sub_uint_t)(a[i]) * b[j]) << ((bits/4) * (i + j));
        }
        else static if (op == "&")
        {
            hi &= y.hi;
            lo &= y.lo;
        }
        else static if (op == "|")
        {
            hi |= y.hi;
            lo |= y.lo;
        }
        else static if (op == "^")
        {
            hi ^= y.hi;
            lo ^= y.lo;
        }
        else static if (op == "/" || op == "%")
        {
            self q = void, r = void;
            static if(signed)
                Internals!bits.signedDivide(this, y, q, r);
            else
                Internals!bits.unsignedDivide(this, y, q, r);
            static if (op == "/")
                this = q;
            else
                this = r;
        }
        else
        {
            static assert(false, "unsupported operation '" ~ op ~ "'");
        }
        return this;
    }

    /// Const unary operators: `+` (identity), `-` (two's-complement
    /// negation) and `~` (bitwise not).
    @nogc self opUnary(string op)() pure const nothrow if (op == "+" || op == "-" || op == "~")
    {
        static if (op == "-")
        {
            self r = this;
            r.not();
            r.increment();
            return r;
        }
        else static if (op == "+")
           return this;
        else static if (op == "~")
        {
            self r = this;
            r.not();
            return r;
        }
    }

    /// Prefix increment/decrement (wrapping); returns the updated value.
    @nogc self opUnary(string op)() pure nothrow if (op == "++" || op == "--")
    {
        static if (op == "++")
            increment();
        else static if (op == "--")
            decrement();
        return this;
    }

    /// Equality with any value convertible to this type.
    @nogc bool opEquals(T)(T y) pure const if (!isSelf!T)
    {
        return this == self(y);
    }

    /// Bitwise equality of both halves.
    @nogc bool opEquals(T)(T y) pure const if (isSelf!T)
    {
       return lo == y.lo && y.hi == hi;
    }

    /// Three-way comparison with any value convertible to this type.
    @nogc int opCmp(T)(T y) pure const if (!isSelf!T)
    {
        return opCmp(self(y));
    }

    /// Three-way comparison: the high half decides (signed for signed
    /// types), then the unsigned low half breaks ties.
    @nogc int opCmp(T)(T y) pure const if (isSelf!T)
    {
        if (hi < y.hi) return -1;
        if (hi > y.hi) return 1;
        if (lo < y.lo) return -1;
        if (lo > y.lo) return 1;
        return 0;
    }

    // binary layout should be what is expected on this platform
    version (LittleEndian)
    {
        low_t lo;
        hi_t hi;
    }
    else
    {
        hi_t hi;
        low_t lo;
    }

    private
    {
        static if (signed)
        {
            @nogc bool isNegative() pure nothrow const
            {
                return signBit();
            }
        }
        else
        {
            @nogc bool isNegative() pure nothrow const
            {
                return false;
            }
        }

        // Bitwise complement of both halves, in place.
        @nogc void not() pure nothrow
        {
            hi = ~hi;
            lo = ~lo;
        }

        // In-place +1, carrying into `hi` when `lo` wraps to zero.
        @nogc void increment() pure nothrow
        {
            ++lo;
            if (lo == 0) ++hi;
        }

        // In-place -1, borrowing from `hi` when `lo` is about to wrap.
        @nogc void decrement() pure nothrow
        {
            if (lo == 0) --hi;
            --lo;
        }

        // Value of the top bit of the whole integer.
        @nogc bool signBit() pure const nothrow
        {
            enum SIGN_SHIFT = bits / 2 - 1;
            return ((hi >> SIGN_SHIFT) & 1) != 0;
        }

        // Splits the value into four bits/4-bit limbs, least significant first.
        @nogc sub_sub_uint_t[4] toParts() pure const nothrow
        {
            enum SHIFT = bits / 4;
            sub_sub_uint_t[4] p = void;
            // Casting to sub_sub_uint_t keeps only the low bits/4 bits, so
            // no explicit mask is needed.
            p[3] = cast(sub_sub_uint_t)(hi >> SHIFT);
            p[2] = cast(sub_sub_uint_t)hi;
            p[1] = cast(sub_sub_uint_t)(lo >> SHIFT);
            p[0] = cast(sub_sub_uint_t)lo;
            return p;
        }
    }
}
567 
/// Absolute value of a wide integer.
/// Negation is two's complement, so for the most negative signed value the
/// result is that same (still negative) value.
@nogc public wideIntImpl!(signed, bits) abs(bool signed, int bits)(wideIntImpl!(signed, bits) x) pure nothrow
{
    return (x < 0) ? -x : x;
}
575 
/// Division helpers shared by the signed and unsigned wide integers of a
/// given width.
private struct Internals(int bits)
{
    alias wint_t = wideIntImpl!(true, bits);
    alias uwint_t = wideIntImpl!(false, bits);

    /// Unsigned division by repeated shift-and-subtract.
    /// Asserts that `divisor` is non-zero.
    @nogc static void unsignedDivide(uwint_t dividend, uwint_t divisor,
                                     out uwint_t quotient, out uwint_t remainder) pure nothrow
    {
        assert(divisor != 0);

        uwint_t acc = 0;         // quotient accumulated so far
        uwint_t rest = dividend; // what is still left to divide

        while (rest >= divisor)
        {
            // Grow `scaled` = divisor << shift for as long as it stays <= rest
            // and doubling it cannot push a bit off the top.
            uwint_t shift = 0;
            uwint_t scaled = divisor;
            while (rest > scaled
                   && !scaled.signBit()
                   && rest >= (scaled << 1))
            {
                scaled <<= 1;
                ++shift;
            }

            rest -= scaled;
            acc += uwint_t(1) << shift;
        }

        quotient = acc;
        remainder = rest;
    }

    /// Signed division truncating toward zero: the quotient is negative when
    /// the operand signs differ, and the remainder takes the dividend's sign.
    @nogc static void signedDivide(wint_t dividend, wint_t divisor,
                                   out wint_t quotient, out wint_t remainder) pure nothrow
    {
        immutable bool dividendNegative = dividend < 0;
        immutable bool divisorNegative = divisor < 0;

        uwint_t uq, ur;
        unsignedDivide(uwint_t(abs(dividend)), uwint_t(abs(divisor)), uq, ur);

        if (dividendNegative)
            ur = -ur;

        if (dividendNegative != divisorNegative)
            uq = -uq;

        quotient = uq;
        remainder = ur;

        assert(remainder == 0 || ((remainder < 0) == dividendNegative));
    }
}
634 
// toString must stay usable from pure / nothrow / @nogc code whenever the
// sink delegate carries those attributes as well.
pure nothrow @nogc unittest
{
    int256 value = 123;

    FormatSpec!char narrowSpec;
    narrowSpec.spec = 's';
    value.toString((const(char)[] chunk) {}, narrowSpec);

    // A dchar sink must receive correctly widened characters.
    FormatSpec!dchar wideSpec;
    wideSpec.spec = 's';
    value.toString((const(dchar)[] chunk) { assert(chunk == "123"); }, wideSpec);
}
650 
// Formatting int128 through std.format: decimal (%s, %d) and hex (%x) output
// for positive and negative values, plus zero and -1.
unittest
{
    // Was `std..string`, which is a syntax error; `format` lives in std.format.
    import std.format : format;

    int128 x;
    x.hi = 1;
    x.lo = 0x158E_4609_13D0_0001;
    assert(format("%s", x) == "20000000000000000001");
    assert(format("%d", x) == "20000000000000000001");
    assert(format("%x", x) == "0x0000000000000001158E460913D00001");

    x.hi = 0xFFFF_FFFF_FFFF_FFFE;
    x.lo = 0xEA71_B9F6_EC2F_FFFF;
    assert(format("%d", x) == "-20000000000000000001");
    assert(format("%x", x) == "0xFFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");

    x.hi = x.lo = 0;
    assert(format("%d", x) == "0");

    x.hi = x.lo = 0xFFFF_FFFF_FFFF_FFFF;
    assert(format("%d", x) == "-1"); // array index boundary condition
}
673 
// Cross-checks int128/uint128 arithmetic and bitwise operators against the
// native long/ulong results, truncated to 64 bits, over a sweep of operands.
unittest
{
    long step = 164703072086692425;
    for (long si = long.min; si <= long.max - step; si += step)
    {
        for (long sj = long.min; sj <= long.max - step; sj += step)
        {
            ulong ui = cast(ulong)si;
            ulong uj = cast(ulong)sj;
            int128 csi = si;
            uint128 cui = si;
            int128 csj = sj;
            uint128 cuj = sj;
            assert(csi == csi);
            assert(~~csi == csi);
            assert(-(-csi) == csi);
            assert(++csi == si + 1);
            assert(--csi == si);

            // Each helper builds assertion source code to be mixed in.
            string testSigned(string op)
            {
                return "assert(cast(ulong)(si" ~ op ~ "sj) == cast(ulong)(csi" ~ op ~ "csj));";
            }

            string testMixed(string op)
            {
                return "assert(cast(ulong)(ui" ~ op ~ "sj) == cast(ulong)(cui" ~ op ~ "csj));"
                     ~ "assert(cast(ulong)(si" ~ op ~ "uj) == cast(ulong)(csi" ~ op ~ "cuj));";
            }

            string testUnsigned(string op)
            {
                return "assert(cast(ulong)(ui" ~ op ~ "uj) == cast(ulong)(cui" ~ op ~ "cuj));";
            }

            string testAll(string op)
            {
                return testSigned(op) ~ testMixed(op) ~ testUnsigned(op);
            }

            mixin(testAll("+"));
            mixin(testAll("-"));
            mixin(testAll("*"));
            mixin(testAll("|"));
            mixin(testAll("&"));
            mixin(testAll("^"));
            if (sj != 0)
            {
                mixin(testSigned("/"));
                mixin(testSigned("%"));
                if (si >= 0 && sj >= 0)
                {
                    // Mixed/unsigned division only matches the native 64-bit
                    // result for non-negative operands: a negative value
                    // sign-extended to 128 bits is a much larger unsigned
                    // number than its 64-bit counterpart.
                    // The helpers only return source text, so they must be
                    // mixed in for the assertions to actually run.
                    mixin(testMixed("/"));
                    mixin(testUnsigned("/"));
                    mixin(testMixed("%"));
                    mixin(testUnsigned("%"));
                }
            }
        }
    }
}
737 
// Compile-time literals: decimal and hex spellings of one value must agree.
unittest
{
    // 20_000_000_000_000_000_001 == 0x1_158E_4609_13D0_0001, slightly above
    // 2^64, so both halves of the int128 are needed.
    enum dec = int128.literal!"20_000_000_000_000_000_001";
    assert(dec.hi == 0x1);
    assert(dec.lo == 0x158E_4609_13D0_0001);
    assert((dec >>> 1) == 0x8AC7_2304_89E8_0000);

    enum hexUpper = int128.literal!"0x1_158E_4609_13D0_0001";
    enum hexLower = int128.literal!"0x1_158e_4609_13d0_0001"; // hex digits are case-insensitive
    assert(dec == hexUpper);
    assert(hexUpper == hexLower);
    assert(dec == hexLower);
}