numpy  2.0.0
src/private/ufunc_override.h
Go to the documentation of this file.
00001 #ifndef __UFUNC_OVERRIDE_H
00002 #define __UFUNC_OVERRIDE_H
00003 #include <npy_config.h>
00004 #include "numpy/arrayobject.h"
00005 #include "common.h"
00006 #include <string.h>
00007 #include "numpy/ufuncobject.h"
00008 
00009 static void
00010 normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
00011                     PyObject **normal_args, PyObject **normal_kwds,
00012                     int nin)
00013 {
00014     /* ufunc.__call__(*args, **kwds) */
00015     int nargs = PyTuple_GET_SIZE(args);
00016     PyObject *obj = PyDict_GetItemString(*normal_kwds, "sig");
00017 
00018     /* ufuncs accept 'sig' or 'signature' normalize to 'signature' */
00019     if (obj != NULL) {
00020         Py_INCREF(obj);
00021         PyDict_SetItemString(*normal_kwds, "signature", obj);
00022         PyDict_DelItemString(*normal_kwds, "sig");
00023     }
00024 
00025     *normal_args = PyTuple_GetSlice(args, 0, nin);
00026 
00027     /* If we have more args than nin, they must be the output variables.*/
00028     if (nargs > nin) {
00029         if ((nargs - nin) == 1) {
00030             obj = PyTuple_GET_ITEM(args, nargs - 1);
00031             PyDict_SetItemString(*normal_kwds, "out", obj);
00032         }
00033         else {
00034             obj = PyTuple_GetSlice(args, nin, nargs);
00035             PyDict_SetItemString(*normal_kwds, "out", obj);
00036             Py_DECREF(obj);
00037         }
00038     }
00039 }
00040 
00041 static void
00042 normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args,
00043                   PyObject **normal_args, PyObject **normal_kwds)
00044 {
00045      /* ufunc.reduce(a[, axis, dtype, out, keepdims]) */
00046     int nargs = PyTuple_GET_SIZE(args);
00047     int i;
00048     PyObject *obj;
00049 
00050     for (i = 0; i < nargs; i++) {
00051         obj = PyTuple_GET_ITEM(args, i);
00052         if (i == 0) {
00053             *normal_args = PyTuple_GetSlice(args, 0, 1);
00054         }
00055         else if (i == 1) {
00056             /* axis */
00057             PyDict_SetItemString(*normal_kwds, "axis", obj);
00058         }
00059         else if (i == 2) {
00060             /* dtype */
00061             PyDict_SetItemString(*normal_kwds, "dtype", obj);
00062         }
00063         else if (i == 3) {
00064             /* out */
00065             PyDict_SetItemString(*normal_kwds, "out", obj);
00066         }
00067         else {
00068             /* keepdims */
00069             PyDict_SetItemString(*normal_kwds, "keepdims", obj);
00070         }
00071     }
00072     return;
00073 }
00074 
00075 static void
00076 normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args,
00077                       PyObject **normal_args, PyObject **normal_kwds)
00078 {
00079      /* ufunc.accumulate(a[, axis, dtype, out]) */
00080     int nargs = PyTuple_GET_SIZE(args);
00081     int i;
00082     PyObject *obj;
00083 
00084     for (i = 0; i < nargs; i++) {
00085         obj = PyTuple_GET_ITEM(args, i);
00086         if (i == 0) {
00087             *normal_args = PyTuple_GetSlice(args, 0, 1);
00088         }
00089         else if (i == 1) {
00090             /* axis */
00091             PyDict_SetItemString(*normal_kwds, "axis", obj);
00092         }
00093         else if (i == 2) {
00094             /* dtype */
00095             PyDict_SetItemString(*normal_kwds, "dtype", obj);
00096         }
00097         else {
00098             /* out */
00099             PyDict_SetItemString(*normal_kwds, "out", obj);
00100         }
00101     }
00102     return;
00103 }
00104 
00105 static void
00106 normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args,
00107                     PyObject **normal_args, PyObject **normal_kwds)
00108 {
00109      /* ufunc.reduceat(a, indicies[, axis, dtype, out]) */
00110     int i;
00111     int nargs = PyTuple_GET_SIZE(args);
00112     PyObject *obj;
00113 
00114     for (i = 0; i < nargs; i++) {
00115         obj = PyTuple_GET_ITEM(args, i);
00116         if (i == 0) {
00117             /* a and indicies */
00118             *normal_args = PyTuple_GetSlice(args, 0, 2);
00119         }
00120         else if (i == 1) {
00121             /* Handled above, when i == 0. */
00122             continue;
00123         }
00124         else if (i == 2) {
00125             /* axis */
00126             PyDict_SetItemString(*normal_kwds, "axis", obj);
00127         }
00128         else if (i == 3) {
00129             /* dtype */
00130             PyDict_SetItemString(*normal_kwds, "dtype", obj);
00131         }
00132         else {
00133             /* out */
00134             PyDict_SetItemString(*normal_kwds, "out", obj);
00135         }
00136     }
00137     return;
00138 }
00139 
00140 static void
00141 normalize_outer_args(PyUFuncObject *ufunc, PyObject *args,
00142                     PyObject **normal_args, PyObject **normal_kwds)
00143 {
00144     /* ufunc.outer(A, B)
00145      * This has no kwds so we don't need to do any kwd stuff.
00146      */
00147     *normal_args = PyTuple_GetSlice(args, 0, 2);
00148     return;
00149 }
00150 
00151 static void
00152 normalize_at_args(PyUFuncObject *ufunc, PyObject *args,
00153                   PyObject **normal_args, PyObject **normal_kwds)
00154 {
00155      /* ufunc.at(a, indices[, b]) */
00156     int nargs = PyTuple_GET_SIZE(args);
00157 
00158     *normal_args = PyTuple_GetSlice(args, 0, nargs);
00159     return;
00160 }
00161 
00162 /*
00163  * Check a set of args for the `__numpy_ufunc__` method.  If more than one of
00164  * the input arguments implements `__numpy_ufunc__`, they are tried in the
00165  * order: subclasses before superclasses, otherwise left to right. The first
00166  * routine returning something other than `NotImplemented` determines the
00167  * result. If all of the `__numpy_ufunc__` operations returns `NotImplemented`,
00168  * a `TypeError` is raised.
00169  *
00170  * Returns 0 on success and 1 on exception. On success, *result contains the
00171  * result of the operation, if any. If *result is NULL, there is no override.
00172  */
00173 static int
00174 PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
00175                       PyObject *args, PyObject *kwds,
00176                       PyObject **result,
00177                       int nin)
00178 {
00179     int i;
00180     int override_pos; /* Position of override in args.*/
00181     int j;
00182 
00183     int nargs;
00184     int nout_kwd = 0;
00185     int out_kwd_is_tuple = 0;
00186     int noa = 0; /* Number of overriding args.*/
00187 
00188     PyObject *obj;
00189     PyObject *out_kwd_obj = NULL;
00190     PyObject *other_obj;
00191 
00192     PyObject *method_name = NULL;
00193     PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */
00194     PyObject *normal_kwds = NULL;
00195 
00196     PyObject *with_override[NPY_MAXARGS];
00197 
00198     /* Pos of each override in args */
00199     int with_override_pos[NPY_MAXARGS];
00200 
00201     /* 2016-01-29: Disable for now in master -- can re-enable once details are
00202      * sorted out. All commented bits are tagged NUMPY_UFUNC_DISABLED. -njs
00203      */
00204     result = NULL;
00205     return 0;
00206 
00207     /*
00208      * Check inputs
00209      */
00210     if (!PyTuple_Check(args)) {
00211         PyErr_SetString(PyExc_ValueError,
00212                         "Internal Numpy error: call to PyUFunc_CheckOverride "
00213                         "with non-tuple");
00214         goto fail;
00215     }
00216     nargs = PyTuple_GET_SIZE(args);
00217     if (nargs > NPY_MAXARGS) {
00218         PyErr_SetString(PyExc_ValueError,
00219                         "Internal Numpy error: too many arguments in call "
00220                         "to PyUFunc_CheckOverride");
00221         goto fail;
00222     }
00223 
00224     /* be sure to include possible 'out' keyword argument. */
00225     if ((kwds)&& (PyDict_CheckExact(kwds))) {
00226         out_kwd_obj = PyDict_GetItemString(kwds, "out");
00227         if (out_kwd_obj != NULL) {
00228             out_kwd_is_tuple = PyTuple_CheckExact(out_kwd_obj);
00229             if (out_kwd_is_tuple) {
00230                 nout_kwd = PyTuple_GET_SIZE(out_kwd_obj);
00231             }
00232             else {
00233                 nout_kwd = 1;
00234             }
00235         }
00236     }
00237 
00238     for (i = 0; i < nargs + nout_kwd; ++i) {
00239         if (i < nargs) {
00240             obj = PyTuple_GET_ITEM(args, i);
00241         }
00242         else {
00243             if (out_kwd_is_tuple) {
00244                 obj = PyTuple_GET_ITEM(out_kwd_obj, i-nargs);
00245             }
00246             else {
00247                 obj = out_kwd_obj;
00248             }
00249         }
00250         /*
00251          * TODO: could use PyArray_GetAttrString_SuppressException if it
00252          * weren't private to multiarray.so
00253          */
00254         if (PyArray_CheckExact(obj) || PyArray_IsScalar(obj, Generic) ||
00255             _is_basic_python_type(obj)) {
00256             continue;
00257         }
00258         if (PyObject_HasAttrString(obj, "__numpy_ufunc__")) {
00259             with_override[noa] = obj;
00260             with_override_pos[noa] = i;
00261             ++noa;
00262         }
00263     }
00264 
00265     /* No overrides, bail out.*/
00266     if (noa == 0) {
00267         *result = NULL;
00268         return 0;
00269     }
00270 
00271     method_name = PyUString_FromString(method);
00272     if (method_name == NULL) {
00273         goto fail;
00274     }
00275 
00276     /*
00277      * Normalize ufunc arguments.
00278      */
00279 
00280     /* Build new kwds */
00281     if (kwds && PyDict_CheckExact(kwds)) {
00282         normal_kwds = PyDict_Copy(kwds);
00283     }
00284     else {
00285         normal_kwds = PyDict_New();
00286     }
00287     if (normal_kwds == NULL) {
00288         goto fail;
00289     }
00290 
00291     /* decide what to do based on the method. */
00292     /* ufunc.__call__ */
00293     if (strcmp(method, "__call__") == 0) {
00294         normalize___call___args(ufunc, args, &normal_args, &normal_kwds, nin);
00295     }
00296 
00297     /* ufunc.reduce */
00298     else if (strcmp(method, "reduce") == 0) {
00299         normalize_reduce_args(ufunc, args, &normal_args, &normal_kwds);
00300     }
00301 
00302     /* ufunc.accumulate */
00303     else if (strcmp(method, "accumulate") == 0) {
00304         normalize_accumulate_args(ufunc, args, &normal_args, &normal_kwds);
00305     }
00306 
00307     /* ufunc.reduceat */
00308     else if (strcmp(method, "reduceat") == 0) {
00309         normalize_reduceat_args(ufunc, args, &normal_args, &normal_kwds);
00310     }
00311 
00312     /* ufunc.outer */
00313     else if (strcmp(method, "outer") == 0) {
00314         normalize_outer_args(ufunc, args, &normal_args, &normal_kwds);
00315     }
00316 
00317     /* ufunc.at */
00318     else if (strcmp(method, "at") == 0) {
00319         normalize_at_args(ufunc, args, &normal_args, &normal_kwds);
00320     }
00321 
00322     if (normal_args == NULL) {
00323         goto fail;
00324     }
00325 
00326     /*
00327      * Call __numpy_ufunc__ functions in correct order
00328      */
00329     while (1) {
00330         PyObject *numpy_ufunc;
00331         PyObject *override_args;
00332         PyObject *override_obj;
00333 
00334         override_obj = NULL;
00335         *result = NULL;
00336 
00337         /* Choose an overriding argument */
00338         for (i = 0; i < noa; i++) {
00339             obj = with_override[i];
00340             if (obj == NULL) {
00341                 continue;
00342             }
00343 
00344             /* Get the first instance of an overriding arg.*/
00345             override_pos = with_override_pos[i];
00346             override_obj = obj;
00347 
00348             /* Check for sub-types to the right of obj. */
00349             for (j = i + 1; j < noa; j++) {
00350                 other_obj = with_override[j];
00351                 if (PyObject_Type(other_obj) != PyObject_Type(obj) &&
00352                     PyObject_IsInstance(other_obj,
00353                                         PyObject_Type(override_obj))) {
00354                     override_obj = NULL;
00355                     break;
00356                 }
00357             }
00358 
00359             /* override_obj had no subtypes to the right. */
00360             if (override_obj) {
00361                 with_override[i] = NULL; /* We won't call this one again */
00362                 break;
00363             }
00364         }
00365 
00366         /* Check if there is a method left to call */
00367         if (!override_obj) {
00368             /* No acceptable override found. */
00369             PyErr_SetString(PyExc_TypeError,
00370                             "__numpy_ufunc__ not implemented for this type.");
00371             goto fail;
00372         }
00373 
00374         /* Call the override */
00375         numpy_ufunc = PyObject_GetAttrString(override_obj,
00376                                              "__numpy_ufunc__");
00377         if (numpy_ufunc == NULL) {
00378             goto fail;
00379         }
00380 
00381         override_args = Py_BuildValue("OOiO", ufunc, method_name,
00382                                       override_pos, normal_args);
00383         if (override_args == NULL) {
00384             Py_DECREF(numpy_ufunc);
00385             goto fail;
00386         }
00387 
00388         *result = PyObject_Call(numpy_ufunc, override_args, normal_kwds);
00389 
00390         Py_DECREF(numpy_ufunc);
00391         Py_DECREF(override_args);
00392 
00393         if (*result == NULL) {
00394             /* Exception occurred */
00395             goto fail;
00396         }
00397         else if (*result == Py_NotImplemented) {
00398             /* Try the next one */
00399             Py_DECREF(*result);
00400             continue;
00401         }
00402         else {
00403             /* Good result. */
00404             break;
00405         }
00406     }
00407 
00408     /* Override found, return it. */
00409     Py_XDECREF(method_name);
00410     Py_XDECREF(normal_args);
00411     Py_XDECREF(normal_kwds);
00412     return 0;
00413 
00414 fail:
00415     Py_XDECREF(method_name);
00416     Py_XDECREF(normal_args);
00417     Py_XDECREF(normal_kwds);
00418     return 1;
00419 }
00420 #endif