Python: Add support for @ infix operator matrix multiplication

This differential revision implements the code for T56276

Reviewers: campbellbarton

Reviewed By: campbellbarton

Differential Revision: https://developer.blender.org/D3587
This commit is contained in:
2018-08-10 14:53:38 +02:00
parent 693ecdf7d3
commit aa5a96430e
4 changed files with 529 additions and 95 deletions

View File

@@ -834,7 +834,7 @@ static PyObject *quat_mul_float(QuaternionObject *quat, const float scalar)
* multiplication */
static PyObject *Quaternion_mul(PyObject *q1, PyObject *q2)
{
float quat[QUAT_SIZE], scalar;
float scalar;
QuaternionObject *quat1 = NULL, *quat2 = NULL;
if (QuaternionObject_Check(q1)) {
@@ -848,9 +848,12 @@ static PyObject *Quaternion_mul(PyObject *q1, PyObject *q2)
return NULL;
}
if (quat1 && quat2) { /* QUAT * QUAT (cross product) */
mul_qt_qtqt(quat, quat1->quat, quat2->quat);
if (quat1 && quat2) { /* QUAT * QUAT (element-wise product) */
#ifdef USE_MATHUTILS_ELEM_MUL
float quat[QUAT_SIZE];
mul_vn_vnvn(quat, quat1->quat, quat2->quat, QUAT_SIZE);
return Quaternion_CreatePyObject(quat, Py_TYPE(q1));
#endif
}
/* the only case this can happen (for a supported type is "FLOAT * QUAT") */
else if (quat2) { /* FLOAT * QUAT */
@@ -858,17 +861,96 @@ static PyObject *Quaternion_mul(PyObject *q1, PyObject *q2)
return quat_mul_float(quat2, scalar);
}
}
else if (quat1) { /* QUAT * FLOAT */
if ((((scalar = PyFloat_AsDouble(q2)) == -1.0f && PyErr_Occurred()) == 0)) {
return quat_mul_float(quat1, scalar);
}
}
PyErr_Format(PyExc_TypeError,
"Element-wise multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(q1)->tp_name, Py_TYPE(q2)->tp_name);
return NULL;
}
/*------------------------obj *= obj------------------------------
* inplace multiplication */
static PyObject *Quaternion_imul(PyObject *q1, PyObject *q2)
{
float scalar;
QuaternionObject *quat1 = NULL, *quat2 = NULL;
if (QuaternionObject_Check(q1)) {
quat1 = (QuaternionObject *)q1;
if (BaseMath_ReadCallback(quat1) == -1)
return NULL;
}
if (QuaternionObject_Check(q2)) {
quat2 = (QuaternionObject *)q2;
if (BaseMath_ReadCallback(quat2) == -1)
return NULL;
}
if (quat1 && quat2) { /* QUAT *= QUAT (inplace element-wise product) */
#ifdef USE_MATHUTILS_ELEM_MUL
mul_vn_vn(quat1->quat, quat2->quat, QUAT_SIZE);
#else
PyErr_Format(PyExc_TypeError,
"Inplace element-wise multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(q1)->tp_name, Py_TYPE(q2)->tp_name);
return NULL;
#endif
}
else if (quat1 && (((scalar = PyFloat_AsDouble(q2)) == -1.0f && PyErr_Occurred()) == 0)) {
/* QUAT *= FLOAT */
mul_qt_fl(quat1->quat, scalar);
}
else {
PyErr_Format(PyExc_TypeError,
"Element-wise multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(q1)->tp_name, Py_TYPE(q2)->tp_name);
return NULL;
}
(void)BaseMath_WriteCallback(quat1);
Py_INCREF(q1);
return q1;
}
/*------------------------obj @ obj------------------------------
* quaternion multiplication */
static PyObject *Quaternion_matmul(PyObject *q1, PyObject *q2)
{
float quat[QUAT_SIZE];
QuaternionObject *quat1 = NULL, *quat2 = NULL;
if (QuaternionObject_Check(q1)) {
quat1 = (QuaternionObject *)q1;
if (BaseMath_ReadCallback(quat1) == -1)
return NULL;
}
if (QuaternionObject_Check(q2)) {
quat2 = (QuaternionObject *)q2;
if (BaseMath_ReadCallback(quat2) == -1)
return NULL;
}
if (quat1 && quat2) { /* QUAT @ QUAT (cross product) */
mul_qt_qtqt(quat, quat1->quat, quat2->quat);
return Quaternion_CreatePyObject(quat, Py_TYPE(q1));
}
else if (quat1) {
/* QUAT * VEC */
/* QUAT @ VEC */
if (VectorObject_Check(q2)) {
VectorObject *vec2 = (VectorObject *)q2;
float tvec[3];
if (vec2->size != 3) {
PyErr_SetString(PyExc_ValueError,
"Vector multiplication: "
"only 3D vector rotations (with quats) "
"currently supported");
"Vector multiplication: "
"only 3D vector rotations (with quats) "
"currently supported");
return NULL;
}
if (BaseMath_ReadCallback(vec2) == -1) {
@@ -880,21 +962,48 @@ static PyObject *Quaternion_mul(PyObject *q1, PyObject *q2)
return Vector_CreatePyObject(tvec, 3, Py_TYPE(vec2));
}
/* QUAT * FLOAT */
else if ((((scalar = PyFloat_AsDouble(q2)) == -1.0f && PyErr_Occurred()) == 0)) {
return quat_mul_float(quat1, scalar);
}
}
else {
BLI_assert(!"internal error");
}
PyErr_Format(PyExc_TypeError,
"Quaternion multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(q1)->tp_name, Py_TYPE(q2)->tp_name);
"Quaternion multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(q1)->tp_name, Py_TYPE(q2)->tp_name);
return NULL;
}
/*------------------------obj @= obj------------------------------
* inplace quaternion multiplication */
static PyObject *Quaternion_imatmul(PyObject *q1, PyObject *q2)
{
float quat[QUAT_SIZE];
QuaternionObject *quat1 = NULL, *quat2 = NULL;
if (QuaternionObject_Check(q1)) {
quat1 = (QuaternionObject *)q1;
if (BaseMath_ReadCallback(quat1) == -1)
return NULL;
}
if (QuaternionObject_Check(q2)) {
quat2 = (QuaternionObject *)q2;
if (BaseMath_ReadCallback(quat2) == -1)
return NULL;
}
if (quat1 && quat2) { /* QUAT @ QUAT (cross product) */
mul_qt_qtqt(quat, quat1->quat, quat2->quat);
copy_qt_qt(quat1->quat, quat);
}
else {
PyErr_Format(PyExc_TypeError,
"Inplace quaternion multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(q1)->tp_name, Py_TYPE(q2)->tp_name);
return NULL;
}
(void)BaseMath_WriteCallback(quat1);
Py_INCREF(q1);
return q1;
}
/* -obj
* returns the negative of this object*/
@@ -952,7 +1061,7 @@ static PyNumberMethods Quaternion_NumMethods = {
NULL, /*nb_float*/
NULL, /* nb_inplace_add */
NULL, /* nb_inplace_subtract */
NULL, /* nb_inplace_multiply */
(binaryfunc) Quaternion_imul, /* nb_inplace_multiply */
NULL, /* nb_inplace_remainder */
NULL, /* nb_inplace_power */
NULL, /* nb_inplace_lshift */
@@ -965,6 +1074,8 @@ static PyNumberMethods Quaternion_NumMethods = {
NULL, /* nb_inplace_floor_divide */
NULL, /* nb_inplace_true_divide */
NULL, /* nb_index */
(binaryfunc) Quaternion_matmul, /* nb_matrix_multiply */
(binaryfunc) Quaternion_imatmul, /* nb_inplace_matrix_multiply */
};
PyDoc_STRVAR(Quaternion_axis_doc,