1 /**
2   Provide a 2^N-bit integer type.
  Guaranteed to never allocate and to have the expected binary layout.
4   Recursive implementation with very slow division.
5 
6   <b>Supports all operations that builtin integers support.</b>
7 
  Bugs: it is unclear whether the unsigned operand takes precedence in a mixed-sign comparison/division.
9   TODO: add literals.
10  */
11 module gfm.integers.wideint;
12 
13 import std.traits,
14        std.ascii;
15 import std.format : FormatSpec;
16 
/// Signed wide integer type.
/// Params:
///    bits = width in bits, must be a power of 2.
template wideint(int bits)
{
    alias wideint = integer!(true, bits);
}
24 
/// Unsigned wide integer type.
/// Params:
///    bits = width in bits, must be a power of 2.
template uwideint(int bits)
{
    alias uwideint = integer!(false, bits);
}
32 
// Some predefined integers (any power of 2 that is at least 128 would work)

alias int128 = wideint!128; // cent and ucent!
alias uint128 = uwideint!128;

alias int256 = wideint!256;
alias uint256 = uwideint!256;
40 
/// Maps a signedness and a power-of-2 bit count to an integer type:
/// the builtin type for 8, 16, 32 and 64 bits, `wideIntImpl` otherwise.
private template integer(bool signed, int bits)
    if ((bits & (bits - 1)) == 0)
{
    // Widths that have a native type forward to it.
    static if (bits == 8)
        alias integer = Select!(signed, byte, ubyte);
    else static if (bits == 16)
        alias integer = Select!(signed, short, ushort);
    else static if (bits == 32)
        alias integer = Select!(signed, int, uint);
    else static if (bits == 64)
        alias integer = Select!(signed, long, ulong);
    else
        alias integer = wideIntImpl!(signed, bits);
}
80 
// Overload selected when `bits` is not a power of 2: instantiating it is a
// compile-time error carrying an explanatory message.
private template integer(bool signed, int bits)
    if ((bits & (bits - 1)) != 0)
{
    static assert(false, "wide integer bits must be a power of 2");
}
86 
87 /// Recursive 2^n integer implementation.
88 struct wideIntImpl(bool signed, int bits)
89 {
90     static assert(bits >= 128);
91     private
92     {
93         alias wideIntImpl self;
94 
95         template isSelf(T)
96         {
97             enum bool isSelf = is(Unqual!T == self);
98         }
99 
100         alias integer!(true, bits/2) sub_int_t;   // signed bits/2 integer
101         alias integer!(false, bits/2) sub_uint_t; // unsigned bits/2 integer
102 
103         alias integer!(true, bits/4) sub_sub_int_t;   // signed bits/4 integer
104         alias integer!(false, bits/4) sub_sub_uint_t; // unsigned bits/4 integer
105 
106         static if(signed)
107             alias sub_int_t hi_t; // hi_t has same signedness as the whole struct
108         else
109             alias sub_uint_t hi_t;
110 
111         alias sub_uint_t low_t;   // low_t is always unsigned
112 
113         enum _bits = bits,
114              _signed = signed;
115     }
116 
117     /// Construct from a value.
118     @nogc this(T)(T x) pure nothrow
119     {
120         opAssign!T(x);
121     }
122 
123     // Private functions used by the `literal` template.
124     private static bool isValidDigitString(string digits)
125     {
126         import std.algorithm : startsWith;
127         import std.ascii : isDigit;
128 
129         if (digits.startsWith("0x"))
130         {
131             foreach (d; digits[2 .. $])
132             {
133                 if (!isHexDigit(d) && d != '_')
134                     return false;
135             }
136         }
137         else // decimal
138         {
139             static if (signed)
140                 if (digits.startsWith("-"))
141                     digits = digits[1 .. $];
142             if (digits.length < 1)
143                 return false;   // at least 1 digit required
144             foreach (d; digits)
145             {
146                 if (!isDigit(d) && d != '_')
147                     return false;
148             }
149         }
150         return true;
151     }
152 
153     private static typeof(this) literalImpl(string digits)
154     {
155         import std.algorithm : startsWith;
156         import std.ascii : isDigit;
157 
158         typeof(this) value = 0;
159         if (digits.startsWith("0x"))
160         {
161             foreach (d; digits[2 .. $])
162             {
163                 if (d == '_')
164                     continue;
165                 value <<= 4;
166                 if (isDigit(d))
167                     value += d - '0';
168                 else
169                     value += 10 + toUpper(d) - 'A';
170             }
171         }
172         else
173         {
174             static if (signed)
175             {
176                 bool negative = false;
177                 if (digits.startsWith("-"))
178                 {
179                     negative = true;
180                     digits = digits[1 .. $];
181                 }
182             }
183             foreach (d; digits)
184             {
185                 if (d == '_')
186                     continue;
187                 value *= 10;
188                 value += d - '0';
189             }
190             static if (signed)
191                 if (negative)
192                     value = -value;
193         }
194         return value;
195     }
196 
    /// Construct from compile-time digit string.
    ///
    /// Both decimal and hex digit strings are supported.
    /// The string is checked with `isValidDigitString`; an invalid string
    /// fails compilation with a message that includes the offending digits.
    ///
    /// Example:
    /// ----
    /// auto x = int128.literal!"20_000_000_000_000_000_001";
    /// assert((x >>> 1) == 0x8AC7_2304_89E8_0000);
    ///
    /// auto y = int128.literal!"0x1_158E_4609_13D0_0001";
    /// assert(y == x);
    /// ----
    template literal(string digits)
    {
        static assert(isValidDigitString(digits),
                      "invalid digits in literal: " ~ digits);
        // Evaluated at compile time because it initializes an enum.
        enum literal = literalImpl(digits);
    }
215 
216     /// Assign with a smaller unsigned type.
217     @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isUnsigned!T)
218     {
219         hi = 0;
220         lo = n;
221         return this;
222     }
223 
224     /// Assign with a smaller signed type (sign is extended).
225     @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isSigned!T)
226     {
227         // shorter int always gets sign-extended,
228         // regardless of the larger int being signed or not
229         hi = (n < 0) ? cast(hi_t)(-1) : cast(hi_t)0;
230 
231         // will also sign extend as well if needed
232         lo = cast(sub_int_t)n;
233         return this;
234     }
235 
236     /// Assign with a wide integer of the same size (sign is lost).
237     @nogc ref self opAssign(T)(T n) pure nothrow if (isWideIntInstantiation!T && T._bits == bits)
238     {
239         hi = n.hi;
240         lo = n.lo;
241         return this;
242     }
243 
244     /// Assign with a smaller wide integer (sign is extended accordingly).
245     @nogc ref self opAssign(T)(T n) pure nothrow if (isWideIntInstantiation!T && T._bits < bits)
246     {
247         static if (T._signed)
248         {
249             // shorter int always gets sign-extended,
250             // regardless of the larger int being signed or not
251             hi = cast(hi_t)((n < 0) ? -1 : 0);
252 
253             // will also sign extend as well if needed
254             lo = cast(sub_int_t)n;
255             return this;
256         }
257         else
258         {
259             hi = 0;
260             lo = n;
261             return this;
262         }
263     }
264 
265     /// Cast to a smaller integer type (truncation).
266     @nogc T opCast(T)() pure const nothrow if (isIntegral!T)
267     {
268         return cast(T)lo;
269     }
270 
271     /// Cast to bool.
272     @nogc T opCast(T)() pure const nothrow if (is(T == bool))
273     {
274         return this != 0;
275     }
276 
277     /// Cast to wide integer of any size.
278     @nogc T opCast(T)() pure const nothrow if (isWideIntInstantiation!T)
279     {
280         static if (T._bits < bits)
281             return cast(T)lo;
282         else
283             return T(this);
284     }
285 
286     /// Converts to a string. Supports format specifiers %d, %s (both decimal)
287     /// and %x (hex).
288     void toString(DG, Char)(DG sink, FormatSpec!Char fmt) const
289         if (is(typeof(sink((const(Char)[]).init))))
290     {
291         if (fmt.spec == 'x')
292         {
293             if (this == 0)
294             {
295                 sink("0");
296                 return;
297             }
298 
299             enum maxDigits = bits / 4;
300             Char[maxDigits] buf;
301             wideIntImpl tmp = this;
302             size_t i;
303 
304             for (i = maxDigits-1; tmp != 0 && i < buf.length; i--)
305             {
306                 buf[i] = hexDigits[cast(int)tmp & 0b00001111];
307                 tmp >>= 4;
308             }
309             assert(i+1 < buf.length);
310             sink(buf[i+1 .. $]);
311         }
312         else // default to decimal
313         {
314             import std.algorithm : reverse;
315 
316             if (this == 0)
317             {
318                 sink("0");
319                 return;
320             }
321 
322             // The maximum number of decimal digits is basically
323             // ceil(log_10(2^^bits - 1)), which is slightly below
324             // ceil(bits * log(2)/log(10)). The value 0.30103 is a slight
325             // overestimate of log(2)/log(10), to be sure we never
326             // underestimate. We add 1 to account for rounding up.
327             enum maxDigits = cast(ulong)(0.30103 * bits) + 1;
328             Char[maxDigits] buf;
329             size_t i;
330 
331             wideIntImpl tmp = this;
332             if (tmp < 0)
333             {
334                 sink("-");
335                 tmp = -tmp;
336             }
337             for (i = maxDigits-1; tmp > 0; i--)
338             {
339                 assert(i < buf.length);
340                 buf[i] = digits[cast(int)(tmp % 10)];
341                 tmp /= 10;
342             }
343             assert(i+1 < buf.length);
344             sink(buf[i+1 .. $]);
345         }
346     }
347 
348     @nogc self opBinary(string op, T)(T o) pure const nothrow if (!isSelf!T)
349     {
350         self r = this;
351         self y = o;
352         return r.opOpAssign!(op)(y);
353     }
354 
355     @nogc self opBinary(string op, T)(T y) pure const nothrow if (isSelf!T)
356     {
357         self r = this; // copy
358         self o = y;
359         return r.opOpAssign!(op)(o);
360     }
361 
362     @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (!isSelf!T)
363     {
364         const(self) o = y;
365         return opOpAssign!(op)(o);
366     }
367 
368     @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (isSelf!T)
369     {
370         static if (op == "+")
371         {
372             hi += y.hi;
373             if (lo + y.lo < lo) // deal with overflow
374                 ++hi;
375             lo += y.lo;
376         }
377         else static if (op == "-")
378         {
379             opOpAssign!"+"(-y);
380         }
381         else static if (op == "<<")
382         {
383             if (y >= bits)
384             {
385                 hi = 0;
386                 lo = 0;
387             }
388             else if (y >= bits / 2)
389             {
390                 hi = lo << (y.lo - bits / 2);
391                 lo = 0;
392             }
393             else if (y > 0)
394             {
395                 hi = (lo >>> (-y.lo + bits / 2)) | (hi << y.lo);
396                 lo = lo << y.lo;
397             }
398         }
399         else static if (op == ">>" || op == ">>>")
400         {
401             assert(y >= 0);
402             static if (!signed || op == ">>>")
403                 immutable(sub_int_t) signFill = 0;
404             else
405                 immutable(sub_int_t) signFill = cast(sub_int_t)(isNegative() ? -1 : 0);
406 
407             if (y >= bits)
408             {
409                 hi = signFill;
410                 lo = signFill;
411             }
412             else if (y >= bits/2)
413             {
414                 lo = hi >> (y.lo - bits/2);
415                 hi = signFill;
416             }
417             else if (y > 0)
418             {
419                 lo = (hi << (-y.lo + bits/2)) | (lo >> y.lo);
420                 hi = hi >> y.lo;
421             }
422         }
423         else static if (op == "*")
424         {
425             sub_sub_uint_t[4] a = toParts();
426             sub_sub_uint_t[4] b = y.toParts();
427 
428             this = 0;
429             for(int i = 0; i < 4; ++i)
430                 for(int j = 0; j < 4 - i; ++j)
431                     this += self(cast(sub_uint_t)(a[i]) * b[j]) << ((bits/4) * (i + j));
432         }
433         else static if (op == "&")
434         {
435             hi &= y.hi;
436             lo &= y.lo;
437         }
438         else static if (op == "|")
439         {
440             hi |= y.hi;
441             lo |= y.lo;
442         }
443         else static if (op == "^")
444         {
445             hi ^= y.hi;
446             lo ^= y.lo;
447         }
448         else static if (op == "/" || op == "%")
449         {
450             self q = void, r = void;
451             static if(signed)
452                 Internals!bits.signedDivide(this, y, q, r);
453             else
454                 Internals!bits.unsignedDivide(this, y, q, r);
455             static if (op == "/")
456                 this = q;
457             else
458                 this = r;
459         }
460         else
461         {
462             static assert(false, "unsupported operation '" ~ op ~ "'");
463         }
464         return this;
465     }
466 
467     // const unary operations
468     @nogc self opUnary(string op)() pure const nothrow if (op == "+" || op == "-" || op == "~")
469     {
470         static if (op == "-")
471         {
472             self r = this;
473             r.not();
474             r.increment();
475             return r;
476         }
477         else static if (op == "+")
478            return this;
479         else static if (op == "~")
480         {
481             self r = this;
482             r.not();
483             return r;
484         }
485     }
486 
487     // non-const unary operations
488     @nogc self opUnary(string op)() pure nothrow if (op == "++" || op == "--")
489     {
490         static if (op == "++")
491             increment();
492         else static if (op == "--")
493             decrement();
494         return this;
495     }
496 
497     @nogc bool opEquals(T)(T y) pure const if (!isSelf!T)
498     {
499         return this == self(y);
500     }
501 
502     @nogc bool opEquals(T)(T y) pure const if (isSelf!T)
503     {
504        return lo == y.lo && y.hi == hi;
505     }
506 
507     @nogc int opCmp(T)(T y) pure const if (!isSelf!T)
508     {
509         return opCmp(self(y));
510     }
511 
512     @nogc int opCmp(T)(T y) pure const if (isSelf!T)
513     {
514         if (hi < y.hi) return -1;
515         if (hi > y.hi) return 1;
516         if (lo < y.lo) return -1;
517         if (lo > y.lo) return 1;
518         return 0;
519     }
520 
    // binary layout should be what is expected on this platform:
    // the low half is declared first on little-endian targets and the
    // high half first otherwise.
    version (LittleEndian)
    {
        low_t lo; // least significant half, always unsigned
        hi_t hi;  // most significant half, same signedness as the struct
    }
    else
    {
        hi_t hi;  // most significant half, same signedness as the struct
        low_t lo; // least significant half, always unsigned
    }
532 
533     private
534     {
535         static if (signed)
536         {
537             @nogc bool isNegative() pure nothrow const
538             {
539                 return signBit();
540             }
541         }
542         else
543         {
544             @nogc bool isNegative() pure nothrow const
545             {
546                 return false;
547             }
548         }
549 
550         @nogc void not() pure nothrow
551         {
552             hi = ~hi;
553             lo = ~lo;
554         }
555 
556         @nogc void increment() pure nothrow
557         {
558             ++lo;
559             if (lo == 0) ++hi;
560         }
561 
562         @nogc void decrement() pure nothrow
563         {
564             if (lo == 0) --hi;
565             --lo;
566         }
567 
568         @nogc bool signBit() pure const nothrow
569         {
570             enum SIGN_SHIFT = bits / 2 - 1;
571             return ((hi >> SIGN_SHIFT) & 1) != 0;
572         }
573 
574         @nogc sub_sub_uint_t[4] toParts() pure const nothrow
575         {
576             sub_sub_uint_t[4] p = void;
577             enum SHIFT = bits / 4;
578             immutable lomask = cast(sub_uint_t)(cast(sub_sub_int_t)(-1));
579             p[3] = cast(sub_sub_uint_t)(hi >> SHIFT);
580             p[2] = cast(sub_sub_uint_t)(hi & lomask);
581             p[1] = cast(sub_sub_uint_t)(lo >> SHIFT);
582             p[0] = cast(sub_sub_uint_t)(lo & lomask);
583             return p;
584         }
585     }
586 }
587 
/// Detects whether `U` is an instantiation of `wideIntImpl`.
/// Evaluates to true when `U.init` can be passed to a function template whose
/// parameter is `wideIntImpl!(signed, bits)` for some deduced `signed` and `bits`.
template isWideIntInstantiation(U)
{
    // Never called at run time; it only exists so that the call below can be type-checked.
    private static void isWideInt(bool signed, int bits)(wideIntImpl!(signed, bits) x)
    {
    }

    enum bool isWideIntInstantiation = is(typeof(isWideInt(U.init)));
}
596 
/// Returns `x` when it is not below zero, and `-x` otherwise.
@nogc public wideIntImpl!(signed, bits) abs(bool signed, int bits)(wideIntImpl!(signed, bits) x) pure nothrow
{
    if (x < 0)
        return -x;
    return x;
}
604 
// Division routines shared by the signed and unsigned wide integers of a given width.
private struct Internals(int bits)
{
    alias wideIntImpl!(true, bits) wint_t;   // signed wide integer of this width
    alias wideIntImpl!(false, bits) uwint_t; // unsigned wide integer of this width

    // Unsigned shift-and-subtract division.
    // Params:
    //    dividend = value to divide
    //    divisor = value to divide by; must not be zero (checked by an assert only)
    //    quotient = receives dividend / divisor
    //    remainder = receives dividend % divisor
    @nogc static void unsignedDivide(uwint_t dividend, uwint_t divisor,
                                     out uwint_t quotient, out uwint_t remainder) pure nothrow
    {
        assert(divisor != 0);

        uwint_t rQuotient = 0;
        uwint_t cDividend = dividend;

        // Each pass subtracts the largest (divisor << N) that still fits and
        // adds (1 << N) to the quotient.
        while (divisor <= cDividend)
        {
            // find N so that (divisor << N) <= cDividend && cDividend < (divisor << (N + 1) )

            uwint_t N = 0;
            uwint_t cDivisor = divisor;
            while (cDividend > cDivisor)
            {
                // stop before the top bit would be shifted out
                if (cDivisor.signBit())
                    break;

                // stop when doubling again would exceed the dividend
                if (cDividend < (cDivisor << 1))
                    break;

                cDivisor <<= 1;
                ++N;
            }
            cDividend = cDividend - cDivisor;
            rQuotient += (uwint_t(1) << N);
        }

        quotient = rQuotient;
        remainder = cDividend;
    }

    // Signed division built on unsignedDivide: divides the absolute values,
    // then fixes up the signs of the results.
    // Params:
    //    dividend = value to divide
    //    divisor = value to divide by
    //    quotient = receives the quotient, negated when the operands have opposite signs
    //    remainder = receives the remainder, which has the sign of the dividend
    @nogc static void signedDivide(wint_t dividend, wint_t divisor,
                                   out wint_t quotient, out wint_t remainder) pure nothrow
    {
        uwint_t q, r;
        unsignedDivide(uwint_t(abs(dividend)), uwint_t(abs(divisor)), q, r);

        // remainder has same sign as the dividend
        if (dividend < 0)
            r = -r;

        // negate the quotient if opposite signs
        if ((dividend >= 0) != (divisor >= 0))
            q = -q;

        quotient = q;
        remainder = r;

        assert(remainder == 0 || ((remainder < 0) == (dividend < 0)));
    }
}
663 
// Verify that toString is callable from pure / nothrow / @nogc code as long as
// the callback also has these attributes.
pure nothrow @nogc unittest
{
    int256 x = 123;
    FormatSpec!char fspec;

    // The sink discards its input: this call only has to compile under the
    // attributes of the enclosing unittest.
    fspec.spec = 's';
    x.toString((const(char)[]) {}, fspec);

    // Verify that wide strings actually work
    FormatSpec!dchar dfspec;
    dfspec.spec = 's';
    x.toString((const(dchar)[] x) { assert(x == "123"); }, dfspec);
}
679 
// Formatting through std.string.format with %s, %d and %x.
unittest
{
    import std.string : format;

    // Positive value that needs more than 64 bits.
    int128 x;
    x.hi = 1;
    x.lo = 0x158E_4609_13D0_0001;
    assert(format("%s", x) == "20000000000000000001");
    assert(format("%d", x) == "20000000000000000001");
    assert(format("%x", x) == "1158E460913D00001");

    // The same magnitude, negated (two's complement).
    x.hi = 0xFFFF_FFFF_FFFF_FFFE;
    x.lo = 0xEA71_B9F6_EC2F_FFFF;
    assert(format("%d", x) == "-20000000000000000001");
    assert(format("%x", x) == "FFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");

    // Zero.
    x.hi = x.lo = 0;
    assert(format("%d", x) == "0");

    // All bits set, i.e. -1 for a signed type.
    x.hi = x.lo = 0xFFFF_FFFF_FFFF_FFFF;
    assert(format("%d", x) == "-1"); // array index boundary condition
}
702 
// Cross-checks 128-bit arithmetic against native 64-bit arithmetic: for a grid
// of operand pairs, the low 64 bits of every wide result must equal the
// native result.
unittest
{
    long step = 164703072086692425;
    for (long si = long.min; si <= long.max - step; si += step)
    {
        for (long sj = long.min; sj <= long.max - step; sj += step)
        {
            ulong ui = cast(ulong)si;
            ulong uj = cast(ulong)sj;
            int128 csi = si;
            uint128 cui = si;
            int128 csj = sj;
            uint128 cuj = sj;
            assert(csi == csi);
            assert(~~csi == csi);
            assert(-(-csi) == csi);
            assert(++csi == si + 1);
            assert(--csi == si);

            // The helpers below only BUILD the source text of assertions;
            // the text has to go through `mixin(...)` to be compiled and run.

            string testSigned(string op)
            {
                return "assert(cast(ulong)(si" ~ op ~ "sj) == cast(ulong)(csi" ~ op ~ "csj));";
            }

            string testMixed(string op)
            {
                return "assert(cast(ulong)(ui" ~ op ~ "sj) == cast(ulong)(cui" ~ op ~ "csj));"
                     ~ "assert(cast(ulong)(si" ~ op ~ "uj) == cast(ulong)(csi" ~ op ~ "cuj));";
            }

            string testUnsigned(string op)
            {
                return "assert(cast(ulong)(ui" ~ op ~ "uj) == cast(ulong)(cui" ~ op ~ "cuj));";
            }

            string testAll(string op)
            {
                return testSigned(op) ~ testMixed(op) ~ testUnsigned(op);
            }

            mixin(testAll("+"));
            mixin(testAll("-"));
            mixin(testAll("*"));
            mixin(testAll("|"));
            mixin(testAll("&"));
            mixin(testAll("^"));
            if (sj != 0)
            {
                mixin(testSigned("/"));
                mixin(testSigned("%"));
                if (si >= 0 && sj >= 0)
                {
                    // those operations are not supposed to be the same at
                    // higher bitdepth: a sign-extended negative may yield higher dividend
                    mixin(testMixed("/"));
                    mixin(testUnsigned("/"));
                    mixin(testMixed("%"));
                    mixin(testUnsigned("%"));
                }
            }
        }
    }
}
766 
// Compile-time literals: decimal and hex spellings of the same value.
unittest
{
    // Just a little over 2^64, so it actually needs int128.
    // Hex value should be 0x1_158E_4609_13D0_0001.
    enum x = int128.literal!"20_000_000_000_000_000_001";
    assert(x.hi == 0x1 && x.lo == 0x158E_4609_13D0_0001);
    assert((x >>> 1) == 0x8AC7_2304_89E8_0000);

    enum y = int128.literal!"0x1_158E_4609_13D0_0001";
    enum z = int128.literal!"0x1_158e_4609_13d0_0001"; // case insensitivity
    assert(x == y && y == z && x == z);
}
779 
// Literal validation, negative literals and hex formatting.
unittest
{
    import std.string : format;

    // Malformed literals that should be rejected
    assert(!__traits(compiles, int128.literal!""));
    assert(!__traits(compiles, int128.literal!"-"));

    // Negative literals should be supported
    auto x = int128.literal!"-20000000000000000001";
    assert(x.hi == 0xFFFF_FFFF_FFFF_FFFE &&
           x.lo == 0xEA71_B9F6_EC2F_FFFF);
    assert(format("%d", x) == "-20000000000000000001");
    assert(format("%x", x) == "FFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");

    // Negative literals should not be supported for unsigned types
    assert(!__traits(compiles, uint128.literal!"-1"));

    // Hex formatting tests
    x = 0;
    assert(format("%x", x) == "0");
    x = -1; // all 128 bits set: 32 hex digits
    assert(format("%x", x) == "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF");
}