numpy  2.0.0
src/private/npy_extint128.h
Go to the documentation of this file.
00001 #ifndef NPY_EXTINT128_H_
00002 #define NPY_EXTINT128_H_
00003 
00004 
00005 typedef struct {
00006     signed char sign;
00007     npy_uint64 lo, hi;
00008 } npy_extint128_t;
00009 
00010 
00011 /* Integer addition with overflow checking */
00012 static NPY_INLINE npy_int64
00013 safe_add(npy_int64 a, npy_int64 b, char *overflow_flag)
00014 {
00015     if (a > 0 && b > NPY_MAX_INT64 - a) {
00016         *overflow_flag = 1;
00017     }
00018     else if (a < 0 && b < NPY_MIN_INT64 - a) {
00019         *overflow_flag = 1;
00020     }
00021     return a + b;
00022 }
00023 
00024 
00025 /* Integer subtraction with overflow checking */
00026 static NPY_INLINE npy_int64
00027 safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag)
00028 {
00029     if (a >= 0 && b < a - NPY_MAX_INT64) {
00030         *overflow_flag = 1;
00031     }
00032     else if (a < 0 && b > a - NPY_MIN_INT64) {
00033         *overflow_flag = 1;
00034     }
00035     return a - b;
00036 }
00037 
00038 
00039 /* Integer multiplication with overflow checking */
00040 static NPY_INLINE npy_int64
00041 safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag)
00042 {
00043     if (a > 0) {
00044         if (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a) {
00045             *overflow_flag = 1;
00046         }
00047     }
00048     else if (a < 0) {
00049         if (b > 0 && a < NPY_MIN_INT64 / b) {
00050             *overflow_flag = 1;
00051         }
00052         else if (b < 0 && a < NPY_MAX_INT64 / b) {
00053             *overflow_flag = 1;
00054         }
00055     }
00056     return a * b;
00057 }
00058 
00059 
00060 /* Long integer init */
00061 static NPY_INLINE npy_extint128_t
00062 to_128(npy_int64 x)
00063 {
00064     npy_extint128_t result;
00065     result.sign = (x >= 0 ? 1 : -1);
00066     if (x >= 0) {
00067         result.lo = x;
00068     }
00069     else {
00070         result.lo = (npy_uint64)(-(x + 1)) + 1;
00071     }
00072     result.hi = 0;
00073     return result;
00074 }
00075 
00076 
00077 static NPY_INLINE npy_int64
00078 to_64(npy_extint128_t x, char *overflow)
00079 {
00080     if (x.hi != 0 ||
00081         (x.sign > 0 && x.lo > NPY_MAX_INT64) ||
00082         (x.sign < 0 && x.lo != 0 && x.lo - 1 > -(NPY_MIN_INT64 + 1))) {
00083         *overflow = 1;
00084     }
00085     return x.lo * x.sign;
00086 }
00087 
00088 
00089 /* Long integer multiply */
00090 static NPY_INLINE npy_extint128_t
00091 mul_64_64(npy_int64 a, npy_int64 b)
00092 {
00093     npy_extint128_t x, y, z;
00094     npy_uint64 x1, x2, y1, y2, r1, r2, prev;
00095 
00096     x = to_128(a);
00097     y = to_128(b);
00098 
00099     x1 = x.lo & 0xffffffff;
00100     x2 = x.lo >> 32;
00101 
00102     y1 = y.lo & 0xffffffff;
00103     y2 = y.lo >> 32;
00104 
00105     r1 = x1*y2;
00106     r2 = x2*y1;
00107 
00108     z.sign = x.sign * y.sign;
00109     z.hi = x2*y2 + (r1 >> 32) + (r2 >> 32);
00110     z.lo = x1*y1;
00111 
00112     /* Add with carry */
00113     prev = z.lo;
00114     z.lo += (r1 << 32);
00115     if (z.lo < prev) {
00116         ++z.hi;
00117     }
00118 
00119     prev = z.lo;
00120     z.lo += (r2 << 32);
00121     if (z.lo < prev) {
00122         ++z.hi;
00123     }
00124 
00125     return z;
00126 }
00127 
00128 
00129 /* Long integer add */
00130 static NPY_INLINE npy_extint128_t
00131 add_128(npy_extint128_t x, npy_extint128_t y, char *overflow)
00132 {
00133     npy_extint128_t z;
00134 
00135     if (x.sign == y.sign) {
00136         z.sign = x.sign;
00137         z.hi = x.hi + y.hi;
00138         if (z.hi < x.hi) {
00139             *overflow = 1;
00140         }
00141         z.lo = x.lo + y.lo;
00142         if (z.lo < x.lo) {
00143             if (z.hi == NPY_MAX_UINT64) {
00144                 *overflow = 1;
00145             }
00146             ++z.hi;
00147         }
00148     }
00149     else if (x.hi > y.hi || (x.hi == y.hi && x.lo >= y.lo)) {
00150         z.sign = x.sign;
00151         z.hi = x.hi - y.hi;
00152         z.lo = x.lo;
00153         z.lo -= y.lo;
00154         if (z.lo > x.lo) {
00155             --z.hi;
00156         }
00157     }
00158     else {
00159         z.sign = y.sign;
00160         z.hi = y.hi - x.hi;
00161         z.lo = y.lo;
00162         z.lo -= x.lo;
00163         if (z.lo > y.lo) {
00164             --z.hi;
00165         }
00166     }
00167 
00168     return z;
00169 }
00170 
00171 
00172 /* Long integer negation */
00173 static NPY_INLINE npy_extint128_t
00174 neg_128(npy_extint128_t x)
00175 {
00176     npy_extint128_t z = x;
00177     z.sign *= -1;
00178     return z;
00179 }
00180 
00181 
00182 static NPY_INLINE npy_extint128_t
00183 sub_128(npy_extint128_t x, npy_extint128_t y, char *overflow)
00184 {
00185     return add_128(x, neg_128(y), overflow);
00186 }
00187 
00188 
00189 static NPY_INLINE npy_extint128_t
00190 shl_128(npy_extint128_t v)
00191 {
00192     npy_extint128_t z;
00193     z = v;
00194     z.hi <<= 1;
00195     z.hi |= (z.lo & (((npy_uint64)1) << 63)) >> 63;
00196     z.lo <<= 1;
00197     return z;
00198 }
00199 
00200 
00201 static NPY_INLINE npy_extint128_t
00202 shr_128(npy_extint128_t v)
00203 {
00204     npy_extint128_t z;
00205     z = v;
00206     z.lo >>= 1;
00207     z.lo |= (z.hi & 0x1) << 63;
00208     z.hi >>= 1;
00209     return z;
00210 }
00211 
00212 static NPY_INLINE int
00213 gt_128(npy_extint128_t a, npy_extint128_t b)
00214 {
00215     if (a.sign > 0 && b.sign > 0) {
00216         return (a.hi > b.hi) || (a.hi == b.hi && a.lo > b.lo);
00217     }
00218     else if (a.sign < 0 && b.sign < 0) {
00219         return (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo);
00220     }
00221     else if (a.sign > 0 && b.sign < 0) {
00222         return a.hi != 0 || a.lo != 0 || b.hi != 0 || b.lo != 0;
00223     }
00224     else {
00225         return 0;
00226     }
00227 }
00228 
00229 
00230 /* Long integer divide */
00231 static NPY_INLINE npy_extint128_t
00232 divmod_128_64(npy_extint128_t x, npy_int64 b, npy_int64 *mod)
00233 {
00234     npy_extint128_t remainder, pointer, result, divisor;
00235     char overflow = 0;
00236 
00237     assert(b > 0);
00238 
00239     if (b <= 1 || x.hi == 0) {
00240         result.sign = x.sign;
00241         result.lo = x.lo / b;
00242         result.hi = x.hi / b;
00243         *mod = x.sign * (x.lo % b);
00244         return result;
00245     }
00246 
00247     /* Long division, not the most efficient choice */
00248     remainder = x;
00249     remainder.sign = 1;
00250 
00251     divisor.sign = 1;
00252     divisor.hi = 0;
00253     divisor.lo = b;
00254 
00255     result.sign = 1;
00256     result.lo = 0;
00257     result.hi = 0;
00258 
00259     pointer.sign = 1;
00260     pointer.lo = 1;
00261     pointer.hi = 0;
00262 
00263     while ((divisor.hi & (((npy_uint64)1) << 63)) == 0 &&
00264            gt_128(remainder, divisor)) {
00265         divisor = shl_128(divisor);
00266         pointer = shl_128(pointer);
00267     }
00268 
00269     while (pointer.lo || pointer.hi) {
00270         if (!gt_128(divisor, remainder)) {
00271             remainder = sub_128(remainder, divisor, &overflow);
00272             result = add_128(result, pointer, &overflow);
00273         }
00274         divisor = shr_128(divisor);
00275         pointer = shr_128(pointer);
00276     }
00277 
00278     /* Fix signs and return; cannot overflow */
00279     result.sign = x.sign;
00280     *mod = x.sign * remainder.lo;
00281 
00282     return result;
00283 }
00284 
00285 
00286 /* Divide and round down (positive divisor; no overflows) */
00287 static NPY_INLINE npy_extint128_t
00288 floordiv_128_64(npy_extint128_t a, npy_int64 b)
00289 {
00290     npy_extint128_t result;
00291     npy_int64 remainder;
00292     char overflow = 0;
00293     assert(b > 0);
00294     result = divmod_128_64(a, b, &remainder);
00295     if (a.sign < 0 && remainder != 0) {
00296         result = sub_128(result, to_128(1), &overflow);
00297     }
00298     return result;
00299 }
00300 
00301 
00302 /* Divide and round up (positive divisor; no overflows) */
00303 static NPY_INLINE npy_extint128_t
00304 ceildiv_128_64(npy_extint128_t a, npy_int64 b)
00305 {
00306     npy_extint128_t result;
00307     npy_int64 remainder;
00308     char overflow = 0;
00309     assert(b > 0);
00310     result = divmod_128_64(a, b, &remainder);
00311     if (a.sign > 0 && remainder != 0) {
00312         result = add_128(result, to_128(1), &overflow);
00313     }
00314     return result;
00315 }
00316 
00317 #endif