1 /**
2   Provide a 2^N-bit integer type.
3   Guaranteed to never allocate and expected binary layout
4   Recursive implementation with very slow division.
5 
6   <b>Supports all operations that builtin integers support.</b>
7 
8   Bugs: it's not sure if the unsigned operand would take precedence in a comparison/division.
9   TODO: add literals.
10  */
11 module gfm.math.wideint;
12 
13 import std.traits,
14        std.ascii;
15 
16 /// Wide signed integer.
17 /// Params:
18 ///    bits = number of bits, must be a power of 2.
19 template wideint(int bits)
20 {
21     alias integer!(true, bits) wideint;
22 }
23 
24 /// Wide unsigned integer.
25 /// Params:
26 ///    bits = number of bits, must be a power of 2.
27 template uwideint(int bits)
28 {
29     alias integer!(false, bits) uwideint;
30 }
31 
32 // Some predefined integers (any power of 2 greater than 128 would work)
33 
34 alias wideint!128 int128; // cent and ucent!
35 alias uwideint!128 uint128;
36 
37 alias wideint!256 int256;
38 alias uwideint!256 uint256;
39 
40 /// Use this template to get an arbitrary sized integer type.
41 private template integer(bool signed, int bits)
42 {
43     static assert((bits & (bits - 1)) == 0); // bits must be a power of 2
44 
45     // forward to native type for lower numbers of bits
46     static if (bits == 8)
47     {
48         static if (signed)
49             alias byte integer;
50         else
51             alias ubyte integer;
52     }
53     else static if (bits == 16)
54     {
55         static if (signed)
56             alias short integer;
57         else
58             alias ushort integer;
59     }
60     else static if (bits == 32)
61     {
62         static if (signed)
63             alias int integer;
64         else
65             alias uint integer;
66     }
67     else static if (bits == 64)
68     {
69         static if (signed)
70             alias long integer;
71         else
72             alias ulong integer;
73     }
74     else
75     {
76         alias wideIntImpl!(signed, bits) integer;
77     }
78 }
79 
80 /// Recursive 2^n integer implementation.
81 struct wideIntImpl(bool signed, int bits)
82 {
83     static assert(bits >= 128);
84     private
85     {
86         alias wideIntImpl self;
87 
88         template isSelf(T)
89         {
90             enum bool isSelf = is(Unqual!T == self);
91         }
92 
93         alias integer!(true, bits/2) sub_int_t;   // signed bits/2 integer
94         alias integer!(false, bits/2) sub_uint_t; // unsigned bits/2 integer
95 
96         alias integer!(true, bits/4) sub_sub_int_t;   // signed bits/4 integer
97         alias integer!(false, bits/4) sub_sub_uint_t; // unsigned bits/4 integer
98 
99         static if(signed)
100             alias sub_int_t hi_t; // hi_t has same signedness as the whole struct
101         else
102             alias sub_uint_t hi_t;
103 
104         alias sub_uint_t low_t;   // low_t is always unsigned
105 
106         enum _isWideIntImpl = true,
107              _bits = bits,
108              _signed = signed;
109     }
110 
111     /// Construct from a value.
112     @nogc this(T)(T x) pure nothrow
113     {
114         opAssign!T(x);
115     }
116 
117     /// Assign with a smaller unsigned type.
118     @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isUnsigned!T)
119     {
120         hi = 0;
121         lo = n;
122         return this;
123     }
124 
125     /// Assign with a smaller signed type (sign is extended).
126     @nogc ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isSigned!T)
127     {
128         // shorter int always gets sign-extended,
129         // regardless of the larger int being signed or not
130         hi = (n < 0) ? cast(hi_t)(-1) : cast(hi_t)0;
131 
132         // will also sign extend as well if needed
133         lo = cast(sub_int_t)n;
134         return this;
135     }
136 
137     /// Assign with a wide integer of the same size (sign is lost).
138     @nogc ref self opAssign(T)(T n) pure nothrow if (is(typeof(T._isWideIntImpl)) && T._bits == bits)
139     {
140         hi = n.hi;
141         lo = n.lo;
142         return this;
143     }
144 
145     /// Assign with a smaller wide integer (sign is extended accordingly).
146     @nogc ref self opAssign(T)(T n) pure nothrow if (is(typeof(T._isWideIntImpl)) && T._bits < bits)
147     {
148         static if (T._signed)
149         {
150             // shorter int always gets sign-extended,
151             // regardless of the larger int being signed or not
152             hi = cast(hi_t)((n < 0) ? -1 : 0);
153 
154             // will also sign extend as well if needed
155             lo = cast(sub_int_t)n;
156             return this;
157         }
158         else
159         {
160             hi = 0;
161             lo = n;
162             return this;
163         }
164     }
165 
166     /// Cast to a smaller integer type (truncation).
167     @nogc T opCast(T)() pure const nothrow if (isIntegral!T)
168     {
169         return cast(T)lo;
170     }
171 
172     /// Cast to bool.
173     @nogc T opCast(T)() pure const nothrow if (is(T == bool))
174     {
175         return this != 0;
176     }
177 
178     /// Cast to wide integer of any size.
179     @nogc T opCast(T)() pure const nothrow if (is(typeof(T._isWideIntImpl)))
180     {
181         static if (T._bits < bits)
182             return cast(T)lo;
183         else
184             return T(this);
185     }
186 
187     /// Converts to a hexadecimal string.
188     string toString() pure const nothrow
189     {
190         string outbuff = "0x";
191         enum hexdigits = bits / 8;
192 
193         for (int i = 0; i < hexdigits; ++i)
194         {
195             outbuff ~= hexDigits[cast(int)((hi >> ((15 - i) * 4)) & 15)];
196         }
197         for (int i = 0; i < hexdigits; ++i)
198         {
199             outbuff ~= hexDigits[cast(int)((lo >> ((15 - i) * 4)) & 15)];
200         }
201         return outbuff;
202     }
203 
204     @nogc self opBinary(string op, T)(T o) pure const nothrow if (!isSelf!T)
205     {
206         self r = this;
207         self y = o;
208         return r.opOpAssign!(op)(y);
209     }
210 
211     @nogc self opBinary(string op, T)(T y) pure const nothrow if (isSelf!T)
212     {
213         self r = this; // copy
214         self o = y;
215         return r.opOpAssign!(op)(o);
216     }
217 
218     @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (!isSelf!T)
219     {
220         const(self) o = y;
221         return opOpAssign!(op)(o);
222     }
223 
224     @nogc ref self opOpAssign(string op, T)(T y) pure nothrow if (isSelf!T)
225     {
226         static if (op == "+")
227         {
228             hi += y.hi;
229             if (lo + y.lo < lo) // deal with overflow
230                 ++hi;
231             lo += y.lo;
232         }
233         else static if (op == "-")
234         {
235             opOpAssign!"+"(-y);
236         }
237         else static if (op == "<<")
238         {
239             if (y >= bits)
240             {
241                 hi = 0;
242                 lo = 0;
243             }
244             else if (y >= bits / 2)
245             {
246                 hi = lo << (y.lo - bits / 2);
247                 lo = 0;
248             }
249             else if (y > 0)
250             {
251                 hi = (lo >>> (-y.lo + bits / 2)) | (hi << y.lo);
252                 lo = lo << y.lo;
253             }
254         }
255         else static if (op == ">>" || op == ">>>")
256         {
257             assert(y >= 0);
258             static if (!signed || op == ">>>")
259                 immutable(sub_int_t) signFill = 0;
260             else
261                 immutable(sub_int_t) signFill = cast(sub_int_t)(isNegative() ? -1 : 0);
262 
263             if (y >= bits)
264             {
265                 hi = signFill;
266                 lo = signFill;
267             }
268             else if (y >= bits/2)
269             {
270                 lo = hi >> (y.lo - bits/2);
271                 hi = signFill;
272             }
273             else if (y > 0)
274             {
275                 lo = (hi << (-y.lo + bits/2)) | (lo >> y.lo);
276                 hi = hi >> y.lo;
277             }
278         }
279         else static if (op == "*")
280         {
281             sub_sub_uint_t[4] a = toParts();
282             sub_sub_uint_t[4] b = y.toParts();
283 
284             this = 0;
285             for(int i = 0; i < 4; ++i)
286                 for(int j = 0; j < 4 - i; ++j)
287                     this += self(cast(sub_uint_t)(a[i]) * b[j]) << ((bits/4) * (i + j));
288         }
289         else static if (op == "&")
290         {
291             hi &= y.hi;
292             lo &= y.lo;
293         }
294         else static if (op == "|")
295         {
296             hi |= y.hi;
297             lo |= y.lo;
298         }
299         else static if (op == "^")
300         {
301             hi ^= y.hi;
302             lo ^= y.lo;
303         }
304         else static if (op == "/" || op == "%")
305         {
306             self q = void, r = void;
307             static if(signed)
308                 Internals!bits.signedDivide(this, y, q, r);
309             else
310                 Internals!bits.unsignedDivide(this, y, q, r);
311             static if (op == "/")
312                 this = q;
313             else
314                 this = r;
315         }
316         else
317         {
318             static assert(false, "unsupported operation '" ~ op ~ "'");
319         }
320         return this;
321     }
322 
323     // const unary operations
324     @nogc self opUnary(string op)() pure const nothrow if (op == "+" || op == "-" || op == "~")
325     {
326         static if (op == "-")
327         {
328             self r = this;
329             r.not();
330             r.increment();
331             return r;
332         }
333         else static if (op == "+")
334            return this;
335         else static if (op == "~")
336         {
337             self r = this;
338             r.not();
339             return r;
340         }
341     }
342 
343     // non-const unary operations
344     @nogc self opUnary(string op)() pure nothrow if (op == "++" || op == "--")
345     {
346         static if (op == "++")
347             increment();
348         else static if (op == "--")
349             decrement();
350         return this;
351     }
352 
353     @nogc bool opEquals(T)(T y) pure const if (!isSelf!T)
354     {
355         return this == self(y);
356     }
357 
358     @nogc bool opEquals(T)(T y) pure const if (isSelf!T)
359     {
360        return lo == y.lo && y.hi == hi;
361     }
362 
363     @nogc int opCmp(T)(T y) pure const if (!isSelf!T)
364     {
365         return opCmp(self(y));
366     }
367 
368     @nogc int opCmp(T)(T y) pure const if (isSelf!T)
369     {
370         if (hi < y.hi) return -1;
371         if (hi > y.hi) return 1;
372         if (lo < y.lo) return -1;
373         if (lo > y.lo) return 1;
374         return 0;
375     }
376 
377     // binary layout should be what is expected on this platform
378     version (LittleEndian)
379     {
380         low_t lo;
381         hi_t hi;
382     }
383     else
384     {
385         hi_t hi;
386         low_t lo;
387     }
388 
389     private
390     {
391         static if (signed)
392         {
393             @nogc bool isNegative() pure nothrow const
394             {
395                 return signBit();
396             }
397         }
398         else
399         {
400             @nogc bool isNegative() pure nothrow const
401             {
402                 return false;
403             }
404         }
405 
406         @nogc void not() pure nothrow
407         {
408             hi = ~hi;
409             lo = ~lo;
410         }
411 
412         @nogc void increment() pure nothrow
413         {
414             ++lo;
415             if (lo == 0) ++hi;
416         }
417 
418         @nogc void decrement() pure nothrow
419         {
420             if (lo == 0) --hi;
421             --lo;
422         }
423 
424         @nogc bool signBit() pure const nothrow
425         {
426             enum SIGN_SHIFT = bits / 2 - 1;
427             return ((hi >> SIGN_SHIFT) & 1) != 0;
428         }
429 
430         @nogc sub_sub_uint_t[4] toParts() pure const nothrow
431         {
432             sub_sub_uint_t[4] p = void;
433             enum SHIFT = bits / 4;
434             immutable lomask = cast(sub_uint_t)(cast(sub_sub_int_t)(-1));
435             p[3] = cast(sub_sub_uint_t)(hi >> SHIFT);
436             p[2] = cast(sub_sub_uint_t)(hi & lomask);
437             p[1] = cast(sub_sub_uint_t)(lo >> SHIFT);
438             p[0] = cast(sub_sub_uint_t)(lo & lomask);
439             return p;
440         }
441     }
442 }
443 
444 @nogc public wideIntImpl!(signed, bits) abs(bool signed, int bits)(wideIntImpl!(signed, bits) x) pure nothrow
445 {
446     if(x >= 0)
447         return x;
448     else
449         return -x;
450 }
451 
452 private struct Internals(int bits)
453 {
454     alias wideIntImpl!(true, bits) wint_t;
455     alias wideIntImpl!(false, bits) uwint_t;
456 
457     @nogc static void unsignedDivide(uwint_t dividend, uwint_t divisor,
458                                      out uwint_t quotient, out uwint_t remainder) pure nothrow
459     {
460         assert(divisor != 0);
461 
462         uwint_t rQuotient = 0;
463         uwint_t cDividend = dividend;
464 
465         while (divisor <= cDividend)
466         {
467             // find N so that (divisor << N) <= cDividend && cDividend < (divisor << (N + 1) )
468 
469             uwint_t N = 0;
470             uwint_t cDivisor = divisor;
471             while (cDividend > cDivisor)
472             {
473                 if (cDivisor.signBit())
474                     break;
475 
476                 if (cDividend < (cDivisor << 1))
477                     break;
478 
479                 cDivisor <<= 1;
480                 ++N;
481             }
482             cDividend = cDividend - cDivisor;
483             rQuotient += (uwint_t(1) << N);
484         }
485 
486         quotient = rQuotient;
487         remainder = cDividend;
488     }
489 
490     @nogc static void signedDivide(wint_t dividend, wint_t divisor,
491                                    out wint_t quotient, out wint_t remainder) pure nothrow
492     {
493         uwint_t q, r;
494         unsignedDivide(uwint_t(abs(dividend)), uwint_t(abs(divisor)), q, r);
495 
496         // remainder has same sign as the dividend
497         if (dividend < 0)
498             r = -r;
499 
500         // negate the quotient if opposite signs
501         if ((dividend >= 0) != (divisor >= 0))
502             q = -q;
503 
504         quotient = q;
505         remainder = r;
506 
507         assert(remainder == 0 || ((remainder < 0) == (dividend < 0)));
508     }
509 }
510 
511 unittest
512 {
513     long step = 164703072086692425;
514     for (long si = long.min; si <= long.max - step; si += step)
515     {
516         for (long sj = long.min; sj <= long.max - step; sj += step)
517         {
518             ulong ui = cast(ulong)si;
519             ulong uj = cast(ulong)sj;
520             int128 csi = si;
521             uint128 cui = si;
522             int128 csj = sj;
523             uint128 cuj = sj;
524             assert(csi == csi);
525             assert(~~csi == csi);
526             assert(-(-csi) == csi);
527             assert(++csi == si + 1);
528             assert(--csi == si);
529 
530             string testSigned(string op)
531             {
532                 return "assert(cast(ulong)(si" ~ op ~ "sj) == cast(ulong)(csi" ~ op ~ "csj));";
533             }
534 
535             string testMixed(string op)
536             {
537                 return "assert(cast(ulong)(ui" ~ op ~ "sj) == cast(ulong)(cui" ~ op ~ "csj));"
538                      ~ "assert(cast(ulong)(si" ~ op ~ "uj) == cast(ulong)(csi" ~ op ~ "cuj));";
539             }
540 
541             string testUnsigned(string op)
542             {
543                 return "assert(cast(ulong)(ui" ~ op ~ "uj) == cast(ulong)(cui" ~ op ~ "cuj));";
544             }
545 
546             string testAll(string op)
547             {
548                 return testSigned(op) ~ testMixed(op) ~ testUnsigned(op);
549             }
550 
551             mixin(testAll("+"));
552             mixin(testAll("-"));
553             mixin(testAll("*"));
554             mixin(testAll("|"));
555             mixin(testAll("&"));
556             mixin(testAll("^"));
557             if (sj != 0)
558             {
559                 mixin(testSigned("/"));
560                 mixin(testSigned("%"));
561                 if (si >= 0 && sj >= 0)
562                 {
563                     // those operations are not supposed to be the same at
564                     // higher bitdepth: a sign-extended negative may yield higher dividend
565                     testMixed("/");
566                     testUnsigned("/");
567                     testMixed("%");
568                     testUnsigned("%");
569                 }
570             }
571         }
572     }
573 }