numpy
2.0.0
|
00001 #ifndef __UFUNC_OVERRIDE_H 00002 #define __UFUNC_OVERRIDE_H 00003 #include <npy_config.h> 00004 #include "numpy/arrayobject.h" 00005 #include "common.h" 00006 #include <string.h> 00007 #include "numpy/ufuncobject.h" 00008 00009 static void 00010 normalize___call___args(PyUFuncObject *ufunc, PyObject *args, 00011 PyObject **normal_args, PyObject **normal_kwds, 00012 int nin) 00013 { 00014 /* ufunc.__call__(*args, **kwds) */ 00015 int nargs = PyTuple_GET_SIZE(args); 00016 PyObject *obj = PyDict_GetItemString(*normal_kwds, "sig"); 00017 00018 /* ufuncs accept 'sig' or 'signature' normalize to 'signature' */ 00019 if (obj != NULL) { 00020 Py_INCREF(obj); 00021 PyDict_SetItemString(*normal_kwds, "signature", obj); 00022 PyDict_DelItemString(*normal_kwds, "sig"); 00023 } 00024 00025 *normal_args = PyTuple_GetSlice(args, 0, nin); 00026 00027 /* If we have more args than nin, they must be the output variables.*/ 00028 if (nargs > nin) { 00029 if ((nargs - nin) == 1) { 00030 obj = PyTuple_GET_ITEM(args, nargs - 1); 00031 PyDict_SetItemString(*normal_kwds, "out", obj); 00032 } 00033 else { 00034 obj = PyTuple_GetSlice(args, nin, nargs); 00035 PyDict_SetItemString(*normal_kwds, "out", obj); 00036 Py_DECREF(obj); 00037 } 00038 } 00039 } 00040 00041 static void 00042 normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args, 00043 PyObject **normal_args, PyObject **normal_kwds) 00044 { 00045 /* ufunc.reduce(a[, axis, dtype, out, keepdims]) */ 00046 int nargs = PyTuple_GET_SIZE(args); 00047 int i; 00048 PyObject *obj; 00049 00050 for (i = 0; i < nargs; i++) { 00051 obj = PyTuple_GET_ITEM(args, i); 00052 if (i == 0) { 00053 *normal_args = PyTuple_GetSlice(args, 0, 1); 00054 } 00055 else if (i == 1) { 00056 /* axis */ 00057 PyDict_SetItemString(*normal_kwds, "axis", obj); 00058 } 00059 else if (i == 2) { 00060 /* dtype */ 00061 PyDict_SetItemString(*normal_kwds, "dtype", obj); 00062 } 00063 else if (i == 3) { 00064 /* out */ 00065 PyDict_SetItemString(*normal_kwds, "out", obj); 00066 } 00067 else { 00068 /* keepdims */ 00069 PyDict_SetItemString(*normal_kwds, "keepdims", obj); 00070 } 00071 } 00072 return; 00073 } 00074 00075 static void 00076 normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args, 00077 PyObject **normal_args, PyObject **normal_kwds) 00078 { 00079 /* ufunc.accumulate(a[, axis, dtype, out]) */ 00080 int nargs = PyTuple_GET_SIZE(args); 00081 int i; 00082 PyObject *obj; 00083 00084 for (i = 0; i < nargs; i++) { 00085 obj = PyTuple_GET_ITEM(args, i); 00086 if (i == 0) { 00087 *normal_args = PyTuple_GetSlice(args, 0, 1); 00088 } 00089 else if (i == 1) { 00090 /* axis */ 00091 PyDict_SetItemString(*normal_kwds, "axis", obj); 00092 } 00093 else if (i == 2) { 00094 /* dtype */ 00095 PyDict_SetItemString(*normal_kwds, "dtype", obj); 00096 } 00097 else { 00098 /* out */ 00099 PyDict_SetItemString(*normal_kwds, "out", obj); 00100 } 00101 } 00102 return; 00103 } 00104 00105 static void 00106 normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args, 00107 PyObject **normal_args, PyObject **normal_kwds) 00108 { 00109 /* ufunc.reduceat(a, indicies[, axis, dtype, out]) */ 00110 int i; 00111 int nargs = PyTuple_GET_SIZE(args); 00112 PyObject *obj; 00113 00114 for (i = 0; i < nargs; i++) { 00115 obj = PyTuple_GET_ITEM(args, i); 00116 if (i == 0) { 00117 /* a and indicies */ 00118 *normal_args = PyTuple_GetSlice(args, 0, 2); 00119 } 00120 else if (i == 1) { 00121 /* Handled above, when i == 0. */ 00122 continue; 00123 } 00124 else if (i == 2) { 00125 /* axis */ 00126 PyDict_SetItemString(*normal_kwds, "axis", obj); 00127 } 00128 else if (i == 3) { 00129 /* dtype */ 00130 PyDict_SetItemString(*normal_kwds, "dtype", obj); 00131 } 00132 else { 00133 /* out */ 00134 PyDict_SetItemString(*normal_kwds, "out", obj); 00135 } 00136 } 00137 return; 00138 } 00139 00140 static void 00141 normalize_outer_args(PyUFuncObject *ufunc, PyObject *args, 00142 PyObject **normal_args, PyObject **normal_kwds) 00143 { 00144 /* ufunc.outer(A, B) 00145 * This has no kwds so we don't need to do any kwd stuff. 00146 */ 00147 *normal_args = PyTuple_GetSlice(args, 0, 2); 00148 return; 00149 } 00150 00151 static void 00152 normalize_at_args(PyUFuncObject *ufunc, PyObject *args, 00153 PyObject **normal_args, PyObject **normal_kwds) 00154 { 00155 /* ufunc.at(a, indices[, b]) */ 00156 int nargs = PyTuple_GET_SIZE(args); 00157 00158 *normal_args = PyTuple_GetSlice(args, 0, nargs); 00159 return; 00160 } 00161 00162 /* 00163 * Check a set of args for the `__numpy_ufunc__` method. If more than one of 00164 * the input arguments implements `__numpy_ufunc__`, they are tried in the 00165 * order: subclasses before superclasses, otherwise left to right. The first 00166 * routine returning something other than `NotImplemented` determines the 00167 * result. If all of the `__numpy_ufunc__` operations returns `NotImplemented`, 00168 * a `TypeError` is raised. 00169 * 00170 * Returns 0 on success and 1 on exception. On success, *result contains the 00171 * result of the operation, if any. If *result is NULL, there is no override. 00172 */ 00173 static int 00174 PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, 00175 PyObject *args, PyObject *kwds, 00176 PyObject **result, 00177 int nin) 00178 { 00179 int i; 00180 int override_pos; /* Position of override in args.*/ 00181 int j; 00182 00183 int nargs; 00184 int nout_kwd = 0; 00185 int out_kwd_is_tuple = 0; 00186 int noa = 0; /* Number of overriding args.*/ 00187 00188 PyObject *obj; 00189 PyObject *out_kwd_obj = NULL; 00190 PyObject *other_obj; 00191 00192 PyObject *method_name = NULL; 00193 PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */ 00194 PyObject *normal_kwds = NULL; 00195 00196 PyObject *with_override[NPY_MAXARGS]; 00197 00198 /* Pos of each override in args */ 00199 int with_override_pos[NPY_MAXARGS]; 00200 00201 /* 2016-01-29: Disable for now in master -- can re-enable once details are 00202 * sorted out. All commented bits are tagged NUMPY_UFUNC_DISABLED. -njs 00203 */ 00204 result = NULL; 00205 return 0; 00206 00207 /* 00208 * Check inputs 00209 */ 00210 if (!PyTuple_Check(args)) { 00211 PyErr_SetString(PyExc_ValueError, 00212 "Internal Numpy error: call to PyUFunc_CheckOverride " 00213 "with non-tuple"); 00214 goto fail; 00215 } 00216 nargs = PyTuple_GET_SIZE(args); 00217 if (nargs > NPY_MAXARGS) { 00218 PyErr_SetString(PyExc_ValueError, 00219 "Internal Numpy error: too many arguments in call " 00220 "to PyUFunc_CheckOverride"); 00221 goto fail; 00222 } 00223 00224 /* be sure to include possible 'out' keyword argument. */ 00225 if ((kwds)&& (PyDict_CheckExact(kwds))) { 00226 out_kwd_obj = PyDict_GetItemString(kwds, "out"); 00227 if (out_kwd_obj != NULL) { 00228 out_kwd_is_tuple = PyTuple_CheckExact(out_kwd_obj); 00229 if (out_kwd_is_tuple) { 00230 nout_kwd = PyTuple_GET_SIZE(out_kwd_obj); 00231 } 00232 else { 00233 nout_kwd = 1; 00234 } 00235 } 00236 } 00237 00238 for (i = 0; i < nargs + nout_kwd; ++i) { 00239 if (i < nargs) { 00240 obj = PyTuple_GET_ITEM(args, i); 00241 } 00242 else { 00243 if (out_kwd_is_tuple) { 00244 obj = PyTuple_GET_ITEM(out_kwd_obj, i-nargs); 00245 } 00246 else { 00247 obj = out_kwd_obj; 00248 } 00249 } 00250 /* 00251 * TODO: could use PyArray_GetAttrString_SuppressException if it 00252 * weren't private to multiarray.so 00253 */ 00254 if (PyArray_CheckExact(obj) || PyArray_IsScalar(obj, Generic) || 00255 _is_basic_python_type(obj)) { 00256 continue; 00257 } 00258 if (PyObject_HasAttrString(obj, "__numpy_ufunc__")) { 00259 with_override[noa] = obj; 00260 with_override_pos[noa] = i; 00261 ++noa; 00262 } 00263 } 00264 00265 /* No overrides, bail out.*/ 00266 if (noa == 0) { 00267 *result = NULL; 00268 return 0; 00269 } 00270 00271 method_name = PyUString_FromString(method); 00272 if (method_name == NULL) { 00273 goto fail; 00274 } 00275 00276 /* 00277 * Normalize ufunc arguments. 00278 */ 00279 00280 /* Build new kwds */ 00281 if (kwds && PyDict_CheckExact(kwds)) { 00282 normal_kwds = PyDict_Copy(kwds); 00283 } 00284 else { 00285 normal_kwds = PyDict_New(); 00286 } 00287 if (normal_kwds == NULL) { 00288 goto fail; 00289 } 00290 00291 /* decide what to do based on the method. */ 00292 /* ufunc.__call__ */ 00293 if (strcmp(method, "__call__") == 0) { 00294 normalize___call___args(ufunc, args, &normal_args, &normal_kwds, nin); 00295 } 00296 00297 /* ufunc.reduce */ 00298 else if (strcmp(method, "reduce") == 0) { 00299 normalize_reduce_args(ufunc, args, &normal_args, &normal_kwds); 00300 } 00301 00302 /* ufunc.accumulate */ 00303 else if (strcmp(method, "accumulate") == 0) { 00304 normalize_accumulate_args(ufunc, args, &normal_args, &normal_kwds); 00305 } 00306 00307 /* ufunc.reduceat */ 00308 else if (strcmp(method, "reduceat") == 0) { 00309 normalize_reduceat_args(ufunc, args, &normal_args, &normal_kwds); 00310 } 00311 00312 /* ufunc.outer */ 00313 else if (strcmp(method, "outer") == 0) { 00314 normalize_outer_args(ufunc, args, &normal_args, &normal_kwds); 00315 } 00316 00317 /* ufunc.at */ 00318 else if (strcmp(method, "at") == 0) { 00319 normalize_at_args(ufunc, args, &normal_args, &normal_kwds); 00320 } 00321 00322 if (normal_args == NULL) { 00323 goto fail; 00324 } 00325 00326 /* 00327 * Call __numpy_ufunc__ functions in correct order 00328 */ 00329 while (1) { 00330 PyObject *numpy_ufunc; 00331 PyObject *override_args; 00332 PyObject *override_obj; 00333 00334 override_obj = NULL; 00335 *result = NULL; 00336 00337 /* Choose an overriding argument */ 00338 for (i = 0; i < noa; i++) { 00339 obj = with_override[i]; 00340 if (obj == NULL) { 00341 continue; 00342 } 00343 00344 /* Get the first instance of an overriding arg.*/ 00345 override_pos = with_override_pos[i]; 00346 override_obj = obj; 00347 00348 /* Check for sub-types to the right of obj. */ 00349 for (j = i + 1; j < noa; j++) { 00350 other_obj = with_override[j]; 00351 if (PyObject_Type(other_obj) != PyObject_Type(obj) && 00352 PyObject_IsInstance(other_obj, 00353 PyObject_Type(override_obj))) { 00354 override_obj = NULL; 00355 break; 00356 } 00357 } 00358 00359 /* override_obj had no subtypes to the right. */ 00360 if (override_obj) { 00361 with_override[i] = NULL; /* We won't call this one again */ 00362 break; 00363 } 00364 } 00365 00366 /* Check if there is a method left to call */ 00367 if (!override_obj) { 00368 /* No acceptable override found. */ 00369 PyErr_SetString(PyExc_TypeError, 00370 "__numpy_ufunc__ not implemented for this type."); 00371 goto fail; 00372 } 00373 00374 /* Call the override */ 00375 numpy_ufunc = PyObject_GetAttrString(override_obj, 00376 "__numpy_ufunc__"); 00377 if (numpy_ufunc == NULL) { 00378 goto fail; 00379 } 00380 00381 override_args = Py_BuildValue("OOiO", ufunc, method_name, 00382 override_pos, normal_args); 00383 if (override_args == NULL) { 00384 Py_DECREF(numpy_ufunc); 00385 goto fail; 00386 } 00387 00388 *result = PyObject_Call(numpy_ufunc, override_args, normal_kwds); 00389 00390 Py_DECREF(numpy_ufunc); 00391 Py_DECREF(override_args); 00392 00393 if (*result == NULL) { 00394 /* Exception occurred */ 00395 goto fail; 00396 } 00397 else if (*result == Py_NotImplemented) { 00398 /* Try the next one */ 00399 Py_DECREF(*result); 00400 continue; 00401 } 00402 else { 00403 /* Good result. */ 00404 break; 00405 } 00406 } 00407 00408 /* Override found, return it. */ 00409 Py_XDECREF(method_name); 00410 Py_XDECREF(normal_args); 00411 Py_XDECREF(normal_kwds); 00412 return 0; 00413 00414 fail: 00415 Py_XDECREF(method_name); 00416 Py_XDECREF(normal_args); 00417 Py_XDECREF(normal_kwds); 00418 return 1; 00419 } 00420 #endif