numpy
2.0.0
|
00001 #ifndef NPY_EXTINT128_H_ 00002 #define NPY_EXTINT128_H_ 00003 00004 00005 typedef struct { 00006 signed char sign; 00007 npy_uint64 lo, hi; 00008 } npy_extint128_t; 00009 00010 00011 /* Integer addition with overflow checking */ 00012 static NPY_INLINE npy_int64 00013 safe_add(npy_int64 a, npy_int64 b, char *overflow_flag) 00014 { 00015 if (a > 0 && b > NPY_MAX_INT64 - a) { 00016 *overflow_flag = 1; 00017 } 00018 else if (a < 0 && b < NPY_MIN_INT64 - a) { 00019 *overflow_flag = 1; 00020 } 00021 return a + b; 00022 } 00023 00024 00025 /* Integer subtraction with overflow checking */ 00026 static NPY_INLINE npy_int64 00027 safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag) 00028 { 00029 if (a >= 0 && b < a - NPY_MAX_INT64) { 00030 *overflow_flag = 1; 00031 } 00032 else if (a < 0 && b > a - NPY_MIN_INT64) { 00033 *overflow_flag = 1; 00034 } 00035 return a - b; 00036 } 00037 00038 00039 /* Integer multiplication with overflow checking */ 00040 static NPY_INLINE npy_int64 00041 safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag) 00042 { 00043 if (a > 0) { 00044 if (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a) { 00045 *overflow_flag = 1; 00046 } 00047 } 00048 else if (a < 0) { 00049 if (b > 0 && a < NPY_MIN_INT64 / b) { 00050 *overflow_flag = 1; 00051 } 00052 else if (b < 0 && a < NPY_MAX_INT64 / b) { 00053 *overflow_flag = 1; 00054 } 00055 } 00056 return a * b; 00057 } 00058 00059 00060 /* Long integer init */ 00061 static NPY_INLINE npy_extint128_t 00062 to_128(npy_int64 x) 00063 { 00064 npy_extint128_t result; 00065 result.sign = (x >= 0 ? 1 : -1); 00066 if (x >= 0) { 00067 result.lo = x; 00068 } 00069 else { 00070 result.lo = (npy_uint64)(-(x + 1)) + 1; 00071 } 00072 result.hi = 0; 00073 return result; 00074 } 00075 00076 00077 static NPY_INLINE npy_int64 00078 to_64(npy_extint128_t x, char *overflow) 00079 { 00080 if (x.hi != 0 || 00081 (x.sign > 0 && x.lo > NPY_MAX_INT64) || 00082 (x.sign < 0 && x.lo != 0 && x.lo - 1 > -(NPY_MIN_INT64 + 1))) { 00083 *overflow = 1; 00084 } 00085 return x.lo * x.sign; 00086 } 00087 00088 00089 /* Long integer multiply */ 00090 static NPY_INLINE npy_extint128_t 00091 mul_64_64(npy_int64 a, npy_int64 b) 00092 { 00093 npy_extint128_t x, y, z; 00094 npy_uint64 x1, x2, y1, y2, r1, r2, prev; 00095 00096 x = to_128(a); 00097 y = to_128(b); 00098 00099 x1 = x.lo & 0xffffffff; 00100 x2 = x.lo >> 32; 00101 00102 y1 = y.lo & 0xffffffff; 00103 y2 = y.lo >> 32; 00104 00105 r1 = x1*y2; 00106 r2 = x2*y1; 00107 00108 z.sign = x.sign * y.sign; 00109 z.hi = x2*y2 + (r1 >> 32) + (r2 >> 32); 00110 z.lo = x1*y1; 00111 00112 /* Add with carry */ 00113 prev = z.lo; 00114 z.lo += (r1 << 32); 00115 if (z.lo < prev) { 00116 ++z.hi; 00117 } 00118 00119 prev = z.lo; 00120 z.lo += (r2 << 32); 00121 if (z.lo < prev) { 00122 ++z.hi; 00123 } 00124 00125 return z; 00126 } 00127 00128 00129 /* Long integer add */ 00130 static NPY_INLINE npy_extint128_t 00131 add_128(npy_extint128_t x, npy_extint128_t y, char *overflow) 00132 { 00133 npy_extint128_t z; 00134 00135 if (x.sign == y.sign) { 00136 z.sign = x.sign; 00137 z.hi = x.hi + y.hi; 00138 if (z.hi < x.hi) { 00139 *overflow = 1; 00140 } 00141 z.lo = x.lo + y.lo; 00142 if (z.lo < x.lo) { 00143 if (z.hi == NPY_MAX_UINT64) { 00144 *overflow = 1; 00145 } 00146 ++z.hi; 00147 } 00148 } 00149 else if (x.hi > y.hi || (x.hi == y.hi && x.lo >= y.lo)) { 00150 z.sign = x.sign; 00151 z.hi = x.hi - y.hi; 00152 z.lo = x.lo; 00153 z.lo -= y.lo; 00154 if (z.lo > x.lo) { 00155 --z.hi; 00156 } 00157 } 00158 else { 00159 z.sign = y.sign; 00160 z.hi = y.hi - x.hi; 00161 z.lo = y.lo; 00162 z.lo -= x.lo; 00163 if (z.lo > y.lo) { 00164 --z.hi; 00165 } 00166 } 00167 00168 return z; 00169 } 00170 00171 00172 /* Long integer negation */ 00173 static NPY_INLINE npy_extint128_t 00174 neg_128(npy_extint128_t x) 00175 { 00176 npy_extint128_t z = x; 00177 z.sign *= -1; 00178 return z; 00179 } 00180 00181 00182 static NPY_INLINE npy_extint128_t 00183 sub_128(npy_extint128_t x, npy_extint128_t y, char *overflow) 00184 { 00185 return add_128(x, neg_128(y), overflow); 00186 } 00187 00188 00189 static NPY_INLINE npy_extint128_t 00190 shl_128(npy_extint128_t v) 00191 { 00192 npy_extint128_t z; 00193 z = v; 00194 z.hi <<= 1; 00195 z.hi |= (z.lo & (((npy_uint64)1) << 63)) >> 63; 00196 z.lo <<= 1; 00197 return z; 00198 } 00199 00200 00201 static NPY_INLINE npy_extint128_t 00202 shr_128(npy_extint128_t v) 00203 { 00204 npy_extint128_t z; 00205 z = v; 00206 z.lo >>= 1; 00207 z.lo |= (z.hi & 0x1) << 63; 00208 z.hi >>= 1; 00209 return z; 00210 } 00211 00212 static NPY_INLINE int 00213 gt_128(npy_extint128_t a, npy_extint128_t b) 00214 { 00215 if (a.sign > 0 && b.sign > 0) { 00216 return (a.hi > b.hi) || (a.hi == b.hi && a.lo > b.lo); 00217 } 00218 else if (a.sign < 0 && b.sign < 0) { 00219 return (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo); 00220 } 00221 else if (a.sign > 0 && b.sign < 0) { 00222 return a.hi != 0 || a.lo != 0 || b.hi != 0 || b.lo != 0; 00223 } 00224 else { 00225 return 0; 00226 } 00227 } 00228 00229 00230 /* Long integer divide */ 00231 static NPY_INLINE npy_extint128_t 00232 divmod_128_64(npy_extint128_t x, npy_int64 b, npy_int64 *mod) 00233 { 00234 npy_extint128_t remainder, pointer, result, divisor; 00235 char overflow = 0; 00236 00237 assert(b > 0); 00238 00239 if (b <= 1 || x.hi == 0) { 00240 result.sign = x.sign; 00241 result.lo = x.lo / b; 00242 result.hi = x.hi / b; 00243 *mod = x.sign * (x.lo % b); 00244 return result; 00245 } 00246 00247 /* Long division, not the most efficient choice */ 00248 remainder = x; 00249 remainder.sign = 1; 00250 00251 divisor.sign = 1; 00252 divisor.hi = 0; 00253 divisor.lo = b; 00254 00255 result.sign = 1; 00256 result.lo = 0; 00257 result.hi = 0; 00258 00259 pointer.sign = 1; 00260 pointer.lo = 1; 00261 pointer.hi = 0; 00262 00263 while ((divisor.hi & (((npy_uint64)1) << 63)) == 0 && 00264 gt_128(remainder, divisor)) { 00265 divisor = shl_128(divisor); 00266 pointer = shl_128(pointer); 00267 } 00268 00269 while (pointer.lo || pointer.hi) { 00270 if (!gt_128(divisor, remainder)) { 00271 remainder = sub_128(remainder, divisor, &overflow); 00272 result = add_128(result, pointer, &overflow); 00273 } 00274 divisor = shr_128(divisor); 00275 pointer = shr_128(pointer); 00276 } 00277 00278 /* Fix signs and return; cannot overflow */ 00279 result.sign = x.sign; 00280 *mod = x.sign * remainder.lo; 00281 00282 return result; 00283 } 00284 00285 00286 /* Divide and round down (positive divisor; no overflows) */ 00287 static NPY_INLINE npy_extint128_t 00288 floordiv_128_64(npy_extint128_t a, npy_int64 b) 00289 { 00290 npy_extint128_t result; 00291 npy_int64 remainder; 00292 char overflow = 0; 00293 assert(b > 0); 00294 result = divmod_128_64(a, b, &remainder); 00295 if (a.sign < 0 && remainder != 0) { 00296 result = sub_128(result, to_128(1), &overflow); 00297 } 00298 return result; 00299 } 00300 00301 00302 /* Divide and round up (positive divisor; no overflows) */ 00303 static NPY_INLINE npy_extint128_t 00304 ceildiv_128_64(npy_extint128_t a, npy_int64 b) 00305 { 00306 npy_extint128_t result; 00307 npy_int64 remainder; 00308 char overflow = 0; 00309 assert(b > 0); 00310 result = divmod_128_64(a, b, &remainder); 00311 if (a.sign > 0 && remainder != 0) { 00312 result = add_128(result, to_128(1), &overflow); 00313 } 00314 return result; 00315 } 00316 00317 #endif