1 /**
2   Provide a 2^N-bit integer type.
  Guaranteed to never allocate and to have the expected binary layout.
4   Recursive implementation with very slow division.
5  
6   <b>Supports all operations that builtin integers support.</b>
7   
  Bugs: it is not certain that the unsigned operand takes precedence in a mixed signed/unsigned comparison or division.
9   TODO: add literals.  
10  */
11 module gfm.math.wideint;
12 
13 import std.traits,
14        std.ascii;
15 
/// Signed wide integer of the requested width.
/// Resolves to `integer!(true, bits)`.
/// Params:
///    bits = number of bits, must be a power of 2.
template wideint(int bits)
{
    alias wideint = integer!(true, bits);
}
23 
/// Unsigned wide integer of the requested width.
/// Resolves to `integer!(false, bits)`.
/// Params:
///    bits = number of bits, must be a power of 2.
template uwideint(int bits)
{
    alias uwideint = integer!(false, bits);
}
31 
// Some predefined integers (any power of 2 that is at least 128 would work)
33 
// 128-bit signed and unsigned integers (cent and ucent!).
alias int128 = wideint!128;
alias uint128 = uwideint!128;

// 256-bit signed and unsigned integers.
alias int256 = wideint!256;
alias uint256 = uwideint!256;
39 
/// Maps a signedness and a bit count to an integer type.
/// Widths of 8, 16, 32 and 64 bits resolve to the matching builtin type;
/// every other width is forwarded to `wideIntImpl`.
/// Params:
///    signed = true to select the signed variant, false for the unsigned one.
///    bits = number of bits, must be a power of 2.
private template integer(bool signed, int bits)
{
    static assert((bits & (bits - 1)) == 0); // bits must be a power of 2

    // Native types cover the small widths; Select picks the signed or
    // unsigned flavour.
    static if (bits == 8)
        alias integer = Select!(signed, byte, ubyte);
    else static if (bits == 16)
        alias integer = Select!(signed, short, ushort);
    else static if (bits == 32)
        alias integer = Select!(signed, int, uint);
    else static if (bits == 64)
        alias integer = Select!(signed, long, ulong);
    else
        alias integer = wideIntImpl!(signed, bits);
}
79 
/// Recursive 2^n integer implementation.
///
/// The value is stored as two halves of `bits/2` bits each: `hi` (which
/// carries the signedness of the whole integer) and `lo` (always unsigned).
/// For `bits == 128` the halves are builtin 64-bit integers; for larger
/// widths they are themselves `wideIntImpl` instances.
/// Params:
///    signed = true for a signed integer, false for an unsigned one.
///    bits = total number of bits, must be at least 128.
align(1) struct wideIntImpl(bool signed, int bits)
{
    align(1):
    static assert(bits >= 128);
    private
    {
        alias wideIntImpl self;

        /// True when T is this very type, ignoring qualifiers.
        template isSelf(T)
        {
            enum bool isSelf = is(Unqual!T == self);
        }

        alias integer!(true, bits/2) sub_int_t;   // signed bits/2 integer
        alias integer!(false, bits/2) sub_uint_t; // unsigned bits/2 integer

        alias integer!(true, bits/4) sub_sub_int_t;   // signed bits/4 integer
        alias integer!(false, bits/4) sub_sub_uint_t; // unsigned bits/4 integer

        static if(signed)
            alias sub_int_t hi_t; // hi_t has same signedness as the whole struct
        else
            alias sub_uint_t hi_t;

        alias sub_uint_t low_t;   // low_t is always unsigned

        // Compile-time markers used by the template constraints below to
        // recognize wide integers and query their width and signedness.
        enum _isWideIntImpl = true,
             _bits = bits,
             _signed = signed;
    }

    /// Construct from a value by forwarding to the matching `opAssign`.
    this(T)(T x) pure nothrow
    {
        opAssign!T(x);
    }

    /// Assign with a smaller unsigned type (the high half is zeroed).
    ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isUnsigned!T)
    {
        hi = 0;
        lo = n;
        return this;
    }

    /// Assign with a smaller signed type (sign is extended).
    ref self opAssign(T)(T n) pure nothrow if (isIntegral!T && isSigned!T)
    {
        // shorter int always gets sign-extended,
        // regardless of the larger int being signed or not
        hi = (n < 0) ? cast(hi_t)(-1) : cast(hi_t)0;

        // will also sign extend as well if needed
        lo = cast(sub_int_t)n;
        return this;
    }

    /// Assign with a wide integer of the same size (sign is lost).
    ref self opAssign(T)(T n) pure nothrow if (is(typeof(T._isWideIntImpl)) && T._bits == bits)
    {
        hi = n.hi;
        lo = n.lo;
        return this;
    }

    /// Assign with a smaller wide integer (sign is extended accordingly).
    ref self opAssign(T)(T n) pure nothrow if (is(typeof(T._isWideIntImpl)) && T._bits < bits)
    {
        static if (T._signed)
        {
            // shorter int always gets sign-extended,
            // regardless of the larger int being signed or not
            hi = cast(hi_t)((n < 0) ? -1 : 0);

            // will also sign extend as well if needed
            lo = cast(sub_int_t)n;
            return this;
        }
        else
        {
            hi = 0;
            lo = n;
            return this;
        }
    }

    /// Cast to a smaller integer type (truncation): only `lo` is used.
    T opCast(T)() pure const nothrow if (isIntegral!T)
    {
        return cast(T)lo;
    }

    /// Cast to bool: true when the value is non-zero.
    T opCast(T)() pure const nothrow if (is(T == bool))
    {
        return this != 0;
    }

    /// Cast to wide integer of any size.
    /// Narrower targets are produced from `lo`; same-size or wider targets
    /// are constructed from the whole value.
    T opCast(T)() pure const nothrow if (is(typeof(T._isWideIntImpl)))
    {
        static if (T._bits < bits)
            return cast(T)lo;
        else
            return T(this);
    }

    /// Converts to a hexadecimal string.
    /// Returns: "0x" followed by `bits / 4` uppercase hexadecimal digits,
    /// most significant digit first.
    string toString() pure const nothrow
    {
        string outbuff = "0x";

        // Each half is bits/2 bits wide, i.e. bits/8 hexadecimal digits.
        enum hexdigits = bits / 8;

        // The shift is derived from the half width instead of a hard-coded
        // 15, so the most significant nibble comes first for every width
        // (for 128 bits this is exactly (15 - i) * 4).
        for (size_t i = 0; i < hexdigits; ++i)
        {
            outbuff ~= hexDigits[cast(int)((hi >> ((hexdigits - 1 - i) * 4)) & 15)];
        }
        for (size_t i = 0; i < hexdigits; ++i)
        {
            outbuff ~= hexDigits[cast(int)((lo >> ((hexdigits - 1 - i) * 4)) & 15)];
        }
        return outbuff;
    }

    /// Binary operator with an operand of another type: the operand is first
    /// converted to this type, then the compound assignment is applied to a copy.
    self opBinary(string op, T)(T o) pure const nothrow if (!isSelf!T)
    {
        self r = this;
        self y = o;
        return r.opOpAssign!(op)(y);
    }

    /// Binary operator with an operand of this type, applied to a copy.
    self opBinary(string op, T)(T y) pure const nothrow if (isSelf!T)
    {
        self r = this; // copy
        self o = y;
        return r.opOpAssign!(op)(o);
    }

    /// Compound assignment with an operand of another type (converted first).
    ref self opOpAssign(string op, T)(T y) pure nothrow if (!isSelf!T)
    {
        const(self) o = y;
        return opOpAssign!(op)(o);
    }

    /// Compound assignment with an operand of this type.
    /// Supports + - << >> >>> * & | ^ / and %; any other operator is
    /// rejected at compile time.
    ref self opOpAssign(string op, T)(T y) pure nothrow if (isSelf!T)
    {
        static if (op == "+")
        {
            hi += y.hi;
            if (lo + y.lo < lo) // deal with overflow
                ++hi;
            lo += y.lo;
        }
        else static if (op == "-")
        {
            // a - b is computed as a + (-b)
            opOpAssign!"+"(-y);
        }
        else static if (op == "<<")
        {
            if (y >= bits)
            {
                // everything is shifted out
                hi = 0;
                lo = 0;
            }
            else if (y >= bits / 2)
            {
                // the low half moves entirely into the high half
                hi = lo << (y.lo - bits / 2);
                lo = 0;
            }
            else if (y > 0)
            {
                // bits leaving the top of lo enter the bottom of hi
                hi = (lo >>> (-y.lo + bits / 2)) | (hi << y.lo);
                lo = lo << y.lo;
            }
        }
        else static if (op == ">>" || op == ">>>")
        {
            assert(y >= 0);

            // Value used for the vacated high bits: zero for logical
            // shifts, the replicated sign for signed arithmetic shifts.
            static if (!signed || op == ">>>")
                immutable(sub_int_t) signFill = 0;
            else
                immutable(sub_int_t) signFill = cast(sub_int_t)(isNegative() ? -1 : 0);

            if (y >= bits)
            {
                hi = signFill;
                lo = signFill;
            }
            else if (y >= bits/2)
            {
                // hi is shifted with the requested operator so that ">>>"
                // on a signed value does not drag sign bits into lo.
                lo = mixin("hi " ~ op ~ " (y.lo - bits/2)");
                hi = signFill;
            }
            else if (y > 0)
            {
                // bits leaving the bottom of hi enter the top of lo
                lo = (hi << (-y.lo + bits/2)) | (lo >> y.lo);
                // same operator as requested: logical for ">>>", arithmetic for ">>"
                hi = mixin("hi " ~ op ~ " y.lo");
            }
        }
        else static if (op == "*")
        {
            // Schoolbook multiplication on four bits/4-wide limbs per operand.
            sub_sub_uint_t[4] a = toParts();
            sub_sub_uint_t[4] b = y.toParts();

            this = 0;
            // Only limb pairs with i + j < 4 are summed; the remaining
            // partial products would be shifted by at least `bits`.
            for(size_t i = 0; i < 4; ++i)
                for(size_t j = 0; j < 4 - i; ++j)
                    this += self(cast(sub_uint_t)(a[i]) * b[j]) << ((bits/4) * (i + j));
        }
        else static if (op == "&")
        {
            hi &= y.hi;
            lo &= y.lo;
        }
        else static if (op == "|")
        {
            hi |= y.hi;
            lo |= y.lo;
        }
        else static if (op == "^")
        {
            hi ^= y.hi;
            lo ^= y.lo;
        }
        else static if (op == "/" || op == "%")
        {
            // q and r are `out` parameters of the divide helpers, so they
            // do not need to be initialized here.
            self q = void, r = void;
            static if(signed)
                Internals!bits.signedDivide(this, y, q, r);
            else
                Internals!bits.unsignedDivide(this, y, q, r);
            static if (op == "/")
                this = q;
            else
                this = r;
        }
        else
        {
            static assert(false, "unsupported operation '" ~ op ~ "'");
        }
        return this;
    }

    // const unary operations
    /// Unary plus, two's complement negation and bitwise complement.
    self opUnary(string op)() pure const nothrow if (op == "+" || op == "-" || op == "~")
    {
        static if (op == "-")
        {
            // negate as ~x + 1
            self r = this;
            r.not();
            r.increment();
            return r;
        }
        else static if (op == "+")
           return this;
        else static if (op == "~")
        {
            self r = this;
            r.not();
            return r;
        }
    }

    // non-const unary operations
    /// Pre-increment and pre-decrement; returns the updated value.
    self opUnary(string op)() pure nothrow if (op == "++" || op == "--")
    {
        static if (op == "++")
            increment();
        else static if (op == "--")
            decrement();
        return this;
    }

    /// Equality with a value of another type (converted to this type first).
    bool opEquals(T)(T y) pure const if (!isSelf!T)
    {
        return this == self(y);
    }

    /// Equality with a value of this type: both halves must match.
    bool opEquals(T)(T y) pure const if (isSelf!T)
    {
       return lo == y.lo && y.hi == hi;
    }

    /// Ordering against a value of another type (converted to this type first).
    int opCmp(T)(T y) pure const if (!isSelf!T)
    {
        return opCmp(self(y));
    }

    /// Ordering against a value of this type.
    /// The high halves are compared first, the low halves break ties.
    /// Returns: -1, 0 or 1.
    int opCmp(T)(T y) pure const if (isSelf!T)
    {
        if (hi < y.hi) return -1;
        if (hi > y.hi) return 1;
        if (lo < y.lo) return -1;
        if (lo > y.lo) return 1;
        return 0;
    }

    // binary layout should be what is expected on this platform
    version (LittleEndian)
    {
        low_t lo;
        hi_t hi;
    }
    else
    {
        hi_t hi;
        low_t lo;
    }

    private
    {
        static if (signed)
        {
            /// True when the sign bit is set.
            bool isNegative() pure nothrow const
            {
                return signBit();
            }
        }
        else
        {
            /// Unsigned values are never negative.
            bool isNegative() pure nothrow const
            {
                return false;
            }
        }

        /// In-place bitwise complement of both halves.
        void not() pure nothrow
        {
            hi = ~hi;
            lo = ~lo;
        }

        /// In-place +1, carrying into hi when lo wraps around to zero.
        void increment() pure nothrow
        {
            ++lo;
            if (lo == 0) ++hi;
        }

        /// In-place -1, borrowing from hi when lo is about to wrap.
        void decrement() pure nothrow
        {
            if (lo == 0) --hi;
            --lo;
        }

        /// Returns the most significant bit of the value (top bit of hi).
        bool signBit() pure const nothrow
        {
            enum SIGN_SHIFT = bits / 2 - 1;
            return ((hi >> SIGN_SHIFT) & 1) != 0;
        }

        /// Splits the value into four bits/4-wide limbs, least significant first.
        sub_sub_uint_t[4] toParts() pure const nothrow
        {
            sub_sub_uint_t[4] p = void;
            enum SHIFT = bits / 4;
            // mask keeping the low bits/4 bits of a half
            immutable lomask = cast(sub_uint_t)(cast(sub_sub_int_t)(-1));
            p[3] = cast(sub_sub_uint_t)(hi >> SHIFT);
            p[2] = cast(sub_sub_uint_t)(hi & lomask);
            p[1] = cast(sub_sub_uint_t)(lo >> SHIFT);
            p[0] = cast(sub_sub_uint_t)(lo & lomask);
            return p;
        }
    }
}
444 
/// Absolute value of a wide integer.
/// Returns: `x` unchanged when `x >= 0`, otherwise `-x`.
public wideIntImpl!(signed, bits) abs(bool signed, int bits)(wideIntImpl!(signed, bits) x) pure nothrow
{
    return (x >= 0) ? x : -x;
}
452 
/// Division helpers shared by the signed and unsigned wide integers of a
/// given width.
private struct Internals(int bits)
{
    alias wideIntImpl!(true, bits) wint_t;
    alias wideIntImpl!(false, bits) uwint_t;

    /// Unsigned division by repeated shift-and-subtract.
    /// Params:
    ///    dividend = value to divide.
    ///    divisor = value to divide by, must not be zero (asserted).
    ///    quotient = receives dividend / divisor.
    ///    remainder = receives dividend % divisor.
    static void unsignedDivide(uwint_t dividend, uwint_t divisor,
                               out uwint_t quotient, out uwint_t remainder) pure nothrow
    {
        assert(divisor != 0);

        uwint_t accumulated = 0;
        uwint_t rest = dividend;

        while (rest >= divisor)
        {
            // Find the largest shift such that (divisor << shift) <= rest,
            // stopping early if another doubling would overflow.
            uwint_t shift = 0;
            uwint_t scaled = divisor;
            for (;;)
            {
                if (!(rest > scaled))
                    break;

                // top bit already set: doubling would lose it
                if (scaled.signBit())
                    break;

                if (rest < (scaled << 1))
                    break;

                scaled <<= 1;
                ++shift;
            }

            // Remove that multiple and record the matching quotient bit.
            rest = rest - scaled;
            accumulated += (uwint_t(1) << shift);
        }

        quotient = accumulated;
        remainder = rest;
    }

    /// Signed division built on top of `unsignedDivide`.
    /// The operands' absolute values are divided, then signs are restored:
    /// the remainder takes the sign of the dividend and the quotient is
    /// negated when the operands have opposite signs.
    static void signedDivide(wint_t dividend, wint_t divisor,
                             out wint_t quotient, out wint_t remainder) pure nothrow
    {
        uwint_t q, r;
        unsignedDivide(uwint_t(abs(dividend)), uwint_t(abs(divisor)), q, r);

        immutable bool dividendIsNegative = dividend < 0;
        immutable bool signsDiffer = (dividend >= 0) != (divisor >= 0);

        // remainder has same sign as the dividend
        if (dividendIsNegative)
            r = -r;

        // negate the quotient if opposite signs
        if (signsDiffer)
            q = -q;

        quotient = q;
        remainder = r;

        assert(remainder == 0 || ((remainder < 0) == (dividend < 0)));
    }
}
511 
// Cross-checks 128-bit wide integer arithmetic against builtin 64-bit
// arithmetic: each operation is run on both and the low 64 bits of the
// results are compared.
unittest
{
    long step = 164703072086692425;
    for (long si = long.min; si <= long.max - step; si += step)
    {
        for (long sj = long.min; sj <= long.max - step; sj += step)
        {
            ulong ui = cast(ulong)si;
            ulong uj = cast(ulong)sj;
            int128 csi = si;
            uint128 cui = si;
            int128 csj = sj;
            uint128 cuj = sj;
            assert(csi == csi);
            assert(~~csi == csi);
            assert(-(-csi) == csi);
            assert(++csi == si + 1);
            assert(--csi == si);

            // The helpers below build assertion source code that is
            // compiled in with mixin().

            // signed op signed
            string testSigned(string op)
            {
                return "assert(cast(ulong)(si" ~ op ~ "sj) == cast(ulong)(csi" ~ op ~ "csj));";
            }

            // unsigned op signed, and signed op unsigned
            string testMixed(string op)
            {
                return "assert(cast(ulong)(ui" ~ op ~ "sj) == cast(ulong)(cui" ~ op ~ "csj));"
                     ~ "assert(cast(ulong)(si" ~ op ~ "uj) == cast(ulong)(csi" ~ op ~ "cuj));";
            }

            // unsigned op unsigned
            string testUnsigned(string op)
            {
                return "assert(cast(ulong)(ui" ~ op ~ "uj) == cast(ulong)(cui" ~ op ~ "cuj));";
            }

            string testAll(string op)
            {
                return testSigned(op) ~ testMixed(op) ~ testUnsigned(op);
            }

            mixin(testAll("+"));
            mixin(testAll("-"));
            mixin(testAll("*"));
            mixin(testAll("|"));
            mixin(testAll("&"));
            mixin(testAll("^"));
            if (sj != 0)
            {
                mixin(testSigned("/"));
                mixin(testSigned("%"));
                if (si >= 0 && sj >= 0)
                {
                    // those operations are not supposed to be the same at
                    // higher bitdepth: a sign-extended negative may yield higher dividend
                    // The generated assertions must go through mixin();
                    // calling the helpers alone only builds a string and
                    // checks nothing.
                    mixin(testMixed("/"));
                    mixin(testUnsigned("/"));
                    mixin(testMixed("%"));
                    mixin(testUnsigned("%"));
                }
            }
        }
    }
}
575