1 /**
2   Provide a 2^N-bit integer type.
3   Guaranteed to never allocate and expected binary layout
4   Recursive implementation with very slow division.
5 
6   <b>Supports all operations that builtin integers support.</b>
7 
8   Bugs: it's not sure if the unsigned operand would take precedence in a comparison/division.
9           - a < b should be an unsigned comparison if at least one operand is unsigned
10           - a / b should be an unsigned division   if at least one operand is unsigned
11  */
12 module gfm.integers.wideint;
13 
14 import std.traits,
15        std.ascii;
16 import std.format : FormatSpec;
17 
18 /// Wide signed integer.
19 /// Params:
20 ///    bits = number of bits, must be a power of 2.
21 template wideint(int bits)
22 {
23     alias integer!(true, bits) wideint;
24 }
25 
26 /// Wide unsigned integer.
27 /// Params:
28 ///    bits = number of bits, must be a power of 2.
29 template uwideint(int bits)
30 {
31     alias integer!(false, bits) uwideint;
32 }
33 
34 // Some predefined integers (any power of 2 greater than 128 would work)
35 
36 alias wideint!128 int128; // cent and ucent!
37 alias uwideint!128 uint128;
38 
39 alias wideint!256 int256;
40 alias uwideint!256 uint256;
41 
42 /// Use this template to get an arbitrary sized integer type.
43 private template integer(bool signed, int bits)
44     if ((bits & (bits - 1)) == 0)
45 {
46 
47     // forward to native type for lower numbers of bits
48     static if (bits == 8)
49     {
50         static if (signed)
51             alias byte integer;
52         else
53             alias ubyte integer;
54     }
55     else static if (bits == 16)
56     {
57         static if (signed)
58             alias short integer;
59         else
60             alias ushort integer;
61     }
62     else static if (bits == 32)
63     {
64         static if (signed)
65             alias int integer;
66         else
67             alias uint integer;
68     }
69     else static if (bits == 64)
70     {
71         static if (signed)
72             alias long integer;
73         else
74             alias ulong integer;
75     }
76     else
77     {
78         alias wideIntImpl!(signed, bits) integer;
79     }
80 }
81 
82 private template integer(bool signed, int bits)
83     if ((bits & (bits - 1)) != 0)
84 {
85     static assert(0, "wide integer bits must be a power of 2");
86 }
87 
88 /// Recursive 2^n integer implementation.
89 struct wideIntImpl(bool signed, int bits)
90 {
91     static assert(bits >= 128);
92     private
93     {
94         alias wideIntImpl self;
95 
96         template isSelf(T)
97         {
98             enum bool isSelf = is(Unqual!T == self);
99         }
100 
101         alias integer!(true, bits/2) sub_int_t;   // signed bits/2 integer
102         alias integer!(false, bits/2) sub_uint_t; // unsigned bits/2 integer
103 
104         alias integer!(true, bits/4) sub_sub_int_t;   // signed bits/4 integer
105         alias integer!(false, bits/4) sub_sub_uint_t; // unsigned bits/4 integer
106 
107         static if(signed)
108             alias sub_int_t hi_t; // hi_t has same signedness as the whole struct
109         else
110             alias sub_uint_t hi_t;
111 
112         alias sub_uint_t low_t;   // low_t is always unsigned
113 
114         enum _bits = bits,
115              _signed = signed;
116     }
117 
118     /// Construct from a value.
119     @nogc this(T)(T x) pure nothrow
120     {
121         opAssign!T(x);
122     }
123 
124     // Private functions used by the `literal` template.
125     private static bool isValidDigitString(string digits)
126     {
127         import std.algorithm : startsWith;
128         import std.ascii : isDigit;
129 
130         if (digits.startsWith("0x"))
131         {
132             foreach (d; digits[2 .. $])
133             {
134                 if (!isHexDigit(d) && d != '_')
135                     return false;
136             }
137         }
138         else // decimal
139         {
140             static if (signed)
141                 if (digits.startsWith("-"))
142                     digits = digits[1 .. $];
143             if (digits.length < 1)
144                 return false;   // at least 1 digit required
145             foreach (d; digits)
146             {
147                 if (!isDigit(d) && d != '_')
148                     return false;
149             }
150         }
151         return true;
152     }
153 
154     private static typeof(this) literalImpl(string digits)
155     {
156         import std.algorithm : startsWith;
157         import std.ascii : isDigit;
158 
159         typeof(this) value = 0;
160         if (digits.startsWith("0x"))
161         {
162             foreach (d; digits[2 .. $])
163             {
164                 if (d == '_')
165                     continue;
166                 value <<= 4;
167                 if (isDigit(d))
168                     value += d - '0';
169                 else
170                     value += 10 + toUpper(d) - 'A';
171             }
172         }
173         else
174         {
175             static if (signed)
176             {
177                 bool negative = false;
178                 if (digits.startsWith("-"))
179                 {
180                     negative = true;
181                     digits = digits[1 .. $];
182                 }
183             }
184             foreach (d; digits)
185             {
186                 if (d == '_')
187                     continue;
188                 value *= 10;
189                 value += d - '0';
190             }
191             static if (signed)
192                 if (negative)
193                     value = -value;
194         }
195         return value;
196     }
197 
198     /// Construct from compile-time digit string.
199     ///
200     /// Both decimal and hex digit strings are supported.
201     ///
202     /// Example:
203     /// ----
204     /// auto x = int128.literal!"20_000_000_000_000_000_001";
205     /// assert((x >>> 1) == 0x8AC7_2304_89E8_0000);
206     ///
207     /// auto y = int126.literal!"0x1_158E_4609_13D0_0001";
208     /// assert(y == x);
209     /// ----
210     template literal(string digits)
211     {
212         static assert(isValidDigitString(digits),
213                       "invalid digits in literal: " ~ digits);
214         enum literal = literalImpl(digits);
215     }
216 
217     /// Assign with a smaller unsigned type.
218     @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isUnsigned!T)
219     {
220         hi = 0;
221         lo = n;
222         return this;
223     }
224 
225     /// Assign with a smaller signed type (sign is extended).
226     @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isSigned!T)
227     {
228         // shorter int always gets sign-extended,
229         // regardless of the larger int being signed or not
230         hi = (n < 0) ? cast(hi_t)(-1) : cast(hi_t)0;
231 
232         // will also sign extend as well if needed
233         lo = cast(sub_int_t)n;
234         return this;
235     }
236 
237     /// Assign with a wide integer of the same size (sign is lost).
238     @nogc ref self opAssign(T)(T n) pure nothrow if (isWideIntInstantiation!T && T._bits == bits)
239     {
240         hi = n.hi;
241         lo = n.lo;
242         return this;
243     }
244 
245     /// Assign with a smaller wide integer (sign is extended accordingly).
246     @nogc ref self opAssign(T)(T n) pure nothrow if (isWideIntInstantiation!T && T._bits < bits)
247     {
248         static if (T._signed)
249         {
250             // shorter int always gets sign-extended,
251             // regardless of the larger int being signed or not
252             hi = cast(hi_t)((n < 0) ? -1 : 0);
253 
254             // will also sign extend as well if needed
255             lo = cast(sub_int_t)n;
256             return this;
257         }
258         else
259         {
260             hi = 0;
261             lo = n;
262             return this;
263         }
264     }
265 
266     /// Cast to a smaller integer type (truncation).
267     @nogc T opCast(T)() pure const nothrow if (isIntegral!T)
268     {
269         return cast(T)lo;
270     }
271 
272     /// Cast to bool.
273     @nogc T opCast(T)() pure const nothrow if (is(T == bool))
274     {
275         return this != 0;
276     }
277 
278     /// Cast to wide integer of any size.
279     @nogc T opCast(T)() pure const nothrow if (isWideIntInstantiation!T)
280     {
281         static if (T._bits < bits)
282             return cast(T)lo;
283         else
284             return T(this);
285     }
286 
287     /// Converts to a string. Supports format specifiers %d, %s (both decimal)
288     /// and %x (hex).
289     void toString(DG, Char)(DG sink, FormatSpec!Char fmt) const
290         if (is(typeof(sink((const(Char)[]).init))))
291     {
292         if (fmt.spec == 'x')
293         {
294             if (this == 0)
295             {
296                 sink("0");
297                 return;
298             }
299 
300             enum maxDigits = bits / 4;
301             Char[maxDigits] buf;
302             wideIntImpl tmp = this;
303             size_t i;
304 
305             for (i = maxDigits-1; tmp != 0 && i < buf.length; i--)
306             {
307                 buf[i] = hexDigits[cast(int)tmp & 0b00001111];
308                 tmp >>= 4;
309             }
310             assert(i+1 < buf.length);
311             sink(buf[i+1 .. $]);
312         }
313         else // default to decimal
314         {
315             import std.algorithm : reverse;
316 
317             if (this == 0)
318             {
319                 sink("0");
320                 return;
321             }
322 
323             // The maximum number of decimal digits is basically
324             // ceil(log_10(2^^bits - 1)), which is slightly below
325             // ceil(bits * log(2)/log(10)). The value 0.30103 is a slight
326             // overestimate of log(2)/log(10), to be sure we never
327             // underestimate. We add 1 to account for rounding up.
328             enum maxDigits = cast(ulong)(0.30103 * bits) + 1;
329             Char[maxDigits] buf;
330             size_t i;
331 
332             wideIntImpl tmp = this;
333             if (tmp < 0)
334             {
335                 sink("-");
336                 tmp = -tmp;
337             }
338             for (i = maxDigits-1; tmp > 0; i--)
339             {
340                 assert(i < buf.length);
341                 buf[i] = digits[cast(int)(tmp % 10)];
342                 tmp /= 10;
343             }
344             assert(i+1 < buf.length);
345             sink(buf[i+1 .. $]);
346         }
347     }
348 
349     @nogc self opBinary(string op, T)(T o) pure const nothrow if (!isSelf!T)
350     {
351         self r = this;
352         self y = o;
353         return r.opOpAssign!(op)(y);
354     }
355 
356     @nogc self opBinary(string op, T)(T y) pure const nothrow if (isSelf!T)
357     {
358         self r = this; // copy
359         self o = y;
360         return r.opOpAssign!(op)(o);
361     }
362 
363     @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (!isSelf!T)
364     {
365         const(self) o = y;
366         return opOpAssign!(op)(o);
367     }
368 
369     @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (isSelf!T)
370     {
371         static if (op == "+")
372         {
373             hi += y.hi;
374             if (lo + y.lo < lo) // deal with overflow
375                 ++hi;
376             lo += y.lo;
377         }
378         else static if (op == "-")
379         {
380             opOpAssign!"+"(-y);
381         }
382         else static if (op == "<<")
383         {
384             if (y >= bits)
385             {
386                 hi = 0;
387                 lo = 0;
388             }
389             else if (y >= bits / 2)
390             {
391                 hi = lo << (y.lo - bits / 2);
392                 lo = 0;
393             }
394             else if (y > 0)
395             {
396                 hi = (lo >>> (-y.lo + bits / 2)) | (hi << y.lo);
397                 lo = lo << y.lo;
398             }
399         }
400         else static if (op == ">>" || op == ">>>")
401         {
402             assert(y >= 0);
403             static if (!signed || op == ">>>")
404                 immutable(sub_int_t) signFill = 0;
405             else
406                 immutable(sub_int_t) signFill = cast(sub_int_t)(isNegative() ? -1 : 0);
407 
408             if (y >= bits)
409             {
410                 hi = signFill;
411                 lo = signFill;
412             }
413             else if (y >= bits/2)
414             {
415                 lo = hi >> (y.lo - bits/2);
416                 hi = signFill;
417             }
418             else if (y > 0)
419             {
420                 lo = (hi << (-y.lo + bits/2)) | (lo >> y.lo);
421                 hi = hi >> y.lo;
422             }
423         }
424         else static if (op == "*")
425         {
426             sub_sub_uint_t[4] a = toParts();
427             sub_sub_uint_t[4] b = y.toParts();
428 
429             this = 0;
430             for(int i = 0; i < 4; ++i)
431                 for(int j = 0; j < 4 - i; ++j)
432                     this += self(cast(sub_uint_t)(a[i]) * b[j]) << ((bits/4) * (i + j));
433         }
434         else static if (op == "&")
435         {
436             hi &= y.hi;
437             lo &= y.lo;
438         }
439         else static if (op == "|")
440         {
441             hi |= y.hi;
442             lo |= y.lo;
443         }
444         else static if (op == "^")
445         {
446             hi ^= y.hi;
447             lo ^= y.lo;
448         }
449         else static if (op == "/" || op == "%")
450         {
451             self q = void, r = void;
452             static if(signed)
453                 Internals!bits.signedDivide(this, y, q, r);
454             else
455                 Internals!bits.unsignedDivide(this, y, q, r);
456             static if (op == "/")
457                 this = q;
458             else
459                 this = r;
460         }
461         else
462         {
463             static assert(false, "unsupported operation '" ~ op ~ "'");
464         }
465         return this;
466     }
467 
468     // const unary operations
469     @nogc self opUnary(string op)() pure const nothrow if (op == "+" || op == "-" || op == "~")
470     {
471         static if (op == "-")
472         {
473             self r = this;
474             r.not();
475             r.increment();
476             return r;
477         }
478         else static if (op == "+")
479            return this;
480         else static if (op == "~")
481         {
482             self r = this;
483             r.not();
484             return r;
485         }
486     }
487 
488     // non-const unary operations
489     @nogc self opUnary(string op)() pure nothrow if (op == "++" || op == "--")
490     {
491         static if (op == "++")
492             increment();
493         else static if (op == "--")
494             decrement();
495         return this;
496     }
497 
498     @nogc bool opEquals(T)(T y) pure const if (!isSelf!T)
499     {
500         return this == self(y);
501     }
502 
503     @nogc bool opEquals(T)(T y) pure const if (isSelf!T)
504     {
505        return lo == y.lo && y.hi == hi;
506     }
507 
508     @nogc int opCmp(T)(T y) pure const if (!isSelf!T)
509     {
510         return opCmp(self(y));
511     }
512 
513     @nogc int opCmp(T)(T y) pure const if (isSelf!T)
514     {
515         if (hi < y.hi) return -1;
516         if (hi > y.hi) return 1;
517         if (lo < y.lo) return -1;
518         if (lo > y.lo) return 1;
519         return 0;
520     }
521 
522     // binary layout should be what is expected on this platform
523     version (LittleEndian)
524     {
525         low_t lo;
526         hi_t hi;
527     }
528     else
529     {
530         hi_t hi;
531         low_t lo;
532     }
533 
534     private
535     {
536         static if (signed)
537         {
538             @nogc bool isNegative() pure nothrow const
539             {
540                 return signBit();
541             }
542         }
543         else
544         {
545             @nogc bool isNegative() pure nothrow const
546             {
547                 return false;
548             }
549         }
550 
551         @nogc void not() pure nothrow
552         {
553             hi = ~hi;
554             lo = ~lo;
555         }
556 
557         @nogc void increment() pure nothrow
558         {
559             ++lo;
560             if (lo == 0) ++hi;
561         }
562 
563         @nogc void decrement() pure nothrow
564         {
565             if (lo == 0) --hi;
566             --lo;
567         }
568 
569         @nogc bool signBit() pure const nothrow
570         {
571             enum SIGN_SHIFT = bits / 2 - 1;
572             return ((hi >> SIGN_SHIFT) & 1) != 0;
573         }
574 
575         @nogc sub_sub_uint_t[4] toParts() pure const nothrow
576         {
577             sub_sub_uint_t[4] p = void;
578             enum SHIFT = bits / 4;
579             immutable lomask = cast(sub_uint_t)(cast(sub_sub_int_t)(-1));
580             p[3] = cast(sub_sub_uint_t)(hi >> SHIFT);
581             p[2] = cast(sub_sub_uint_t)(hi & lomask);
582             p[1] = cast(sub_sub_uint_t)(lo >> SHIFT);
583             p[0] = cast(sub_sub_uint_t)(lo & lomask);
584             return p;
585         }
586     }
587 }
588 
589 template isWideIntInstantiation(U)
590 {
591     private static void isWideInt(bool signed, int bits)(wideIntImpl!(signed, bits) x)
592     {
593     }
594 
595     enum bool isWideIntInstantiation = is(typeof(isWideInt(U.init)));
596 }
597 
598 @nogc public wideIntImpl!(signed, bits) abs(bool signed, int bits)(wideIntImpl!(signed, bits) x) pure nothrow
599 {
600     if(x >= 0)
601         return x;
602     else
603         return -x;
604 }
605 
606 private struct Internals(int bits)
607 {
608     alias wideIntImpl!(true, bits) wint_t;
609     alias wideIntImpl!(false, bits) uwint_t;
610 
611     @nogc static void unsignedDivide(uwint_t dividend, uwint_t divisor,
612                                      out uwint_t quotient, out uwint_t remainder) pure nothrow
613     {
614         assert(divisor != 0);
615 
616         uwint_t rQuotient = 0;
617         uwint_t cDividend = dividend;
618 
619         while (divisor <= cDividend)
620         {
621             // find N so that (divisor << N) <= cDividend && cDividend < (divisor << (N + 1) )
622 
623             uwint_t N = 0;
624             uwint_t cDivisor = divisor;
625             while (cDividend > cDivisor)
626             {
627                 if (cDivisor.signBit())
628                     break;
629 
630                 if (cDividend < (cDivisor << 1))
631                     break;
632 
633                 cDivisor <<= 1;
634                 ++N;
635             }
636             cDividend = cDividend - cDivisor;
637             rQuotient += (uwint_t(1) << N);
638         }
639 
640         quotient = rQuotient;
641         remainder = cDividend;
642     }
643 
644     @nogc static void signedDivide(wint_t dividend, wint_t divisor,
645                                    out wint_t quotient, out wint_t remainder) pure nothrow
646     {
647         uwint_t q, r;
648         unsignedDivide(uwint_t(abs(dividend)), uwint_t(abs(divisor)), q, r);
649 
650         // remainder has same sign as the dividend
651         if (dividend < 0)
652             r = -r;
653 
654         // negate the quotient if opposite signs
655         if ((dividend >= 0) != (divisor >= 0))
656             q = -q;
657 
658         quotient = q;
659         remainder = r;
660 
661         assert(remainder == 0 || ((remainder < 0) == (dividend < 0)));
662     }
663 }
664 
665 // Verify that toString is callable from pure / nothrow / @nogc code as long as
666 // the callback also has these attributes.
667 pure nothrow @nogc unittest
668 {
669     int256 x = 123;
670     FormatSpec!char fspec;
671 
672     fspec.spec = 's';
673     x.toString((const(char)[]) {}, fspec);
674 
675     // Verify that wide strings actually work
676     FormatSpec!dchar dfspec;
677     dfspec.spec = 's';
678     x.toString((const(dchar)[] x) { assert(x == "123"); }, dfspec);
679 }
680 
681 unittest
682 {
683     import std.string : format;
684 
685     int128 x;
686     x.hi = 1;
687     x.lo = 0x158E_4609_13D0_0001;
688     assert(format("%s", x) == "20000000000000000001");
689     assert(format("%d", x) == "20000000000000000001");
690     assert(format("%x", x) == "1158E460913D00001");
691 
692     x.hi = 0xFFFF_FFFF_FFFF_FFFE;
693     x.lo = 0xEA71_B9F6_EC2F_FFFF;
694     assert(format("%d", x) == "-20000000000000000001");
695     assert(format("%x", x) == "FFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");
696 
697     x.hi = x.lo = 0;
698     assert(format("%d", x) == "0");
699 
700     x.hi = x.lo = 0xFFFF_FFFF_FFFF_FFFF;
701     assert(format("%d", x) == "-1"); // array index boundary condition
702 }
703 
704 unittest
705 {
706     long step = 164703072086692425;
707     for (long si = long.min; si <= long.max - step; si += step)
708     {
709         for (long sj = long.min; sj <= long.max - step; sj += step)
710         {
711             ulong ui = cast(ulong)si;
712             ulong uj = cast(ulong)sj;
713             int128 csi = si;
714             uint128 cui = si;
715             int128 csj = sj;
716             uint128 cuj = sj;
717             assert(csi == csi);
718             assert(~~csi == csi);
719             assert(-(-csi) == csi);
720             assert(++csi == si + 1);
721             assert(--csi == si);
722 
723             string testSigned(string op)
724             {
725                 return "assert(cast(ulong)(si" ~ op ~ "sj) == cast(ulong)(csi" ~ op ~ "csj));";
726             }
727 
728             string testMixed(string op)
729             {
730                 return "assert(cast(ulong)(ui" ~ op ~ "sj) == cast(ulong)(cui" ~ op ~ "csj));"
731                      ~ "assert(cast(ulong)(si" ~ op ~ "uj) == cast(ulong)(csi" ~ op ~ "cuj));";
732             }
733 
734             string testUnsigned(string op)
735             {
736                 return "assert(cast(ulong)(ui" ~ op ~ "uj) == cast(ulong)(cui" ~ op ~ "cuj));";
737             }
738 
739             string testAll(string op)
740             {
741                 return testSigned(op) ~ testMixed(op) ~ testUnsigned(op);
742             }
743 
744             mixin(testAll("+"));
745             mixin(testAll("-"));
746             mixin(testAll("*"));
747             mixin(testAll("|"));
748             mixin(testAll("&"));
749             mixin(testAll("^"));
750             if (sj != 0)
751             {
752                 mixin(testSigned("/"));
753                 mixin(testSigned("%"));
754                 if (si >= 0 && sj >= 0)
755                 {
756                     // those operations are not supposed to be the same at
757                     // higher bitdepth: a sign-extended negative may yield higher dividend
758                     testMixed("/");
759                     testUnsigned("/");
760                     testMixed("%");
761                     testUnsigned("%");
762                 }
763             }
764         }
765     }
766 }
767 
768 unittest
769 {
770     // Just a little over 2^64, so it actually needs int128.
771     // Hex value should be 0x1_158E_4609_13D0_0001.
772     enum x = int128.literal!"20_000_000_000_000_000_001";
773     assert(x.hi == 0x1 && x.lo == 0x158E_4609_13D0_0001);
774     assert((x >>> 1) == 0x8AC7_2304_89E8_0000);
775 
776     enum y = int128.literal!"0x1_158E_4609_13D0_0001";
777     enum z = int128.literal!"0x1_158e_4609_13d0_0001"; // case insensitivity
778     assert(x == y && y == z && x == z);
779 }
780 
781 unittest
782 {
783     import std.string : format;
784 
785     // Malformed literals that should be rejected
786     assert(!__traits(compiles, int128.literal!""));
787     assert(!__traits(compiles, int128.literal!"-"));
788 
789     // Negative literals should be supported
790     auto x = int128.literal!"-20000000000000000001";
791     assert(x.hi == 0xFFFF_FFFF_FFFF_FFFE &&
792            x.lo == 0xEA71_B9F6_EC2F_FFFF);
793     assert(format("%d", x) == "-20000000000000000001");
794     assert(format("%x", x) == "FFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");
795 
796     // Negative literals should not be supported for unsigned types
797     assert(!__traits(compiles, uint128.literal!"-1"));
798 
799     // Hex formatting tests
800     x = 0;
801     assert(format("%x", x) == "0");
802     x = -1;
803     assert(format("%x", x) == "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF");
804 }