Python: Add support for @ infix operator matrix multiplication

This differential revision implements the code for T56276

Reviewers: campbellbarton

Reviewed By: campbellbarton

Differential Revision: https://developer.blender.org/D3587
This commit is contained in:
2018-08-10 14:53:38 +02:00
parent 693ecdf7d3
commit aa5a96430e
4 changed files with 529 additions and 95 deletions

View File

@@ -2321,7 +2321,7 @@ static PyObject *Matrix_sub(PyObject *m1, PyObject *m2)
return Matrix_CreatePyObject(mat, mat1->num_col, mat1->num_row, Py_TYPE(mat1));
}
/*------------------------obj * obj------------------------------
* multiplication */
* element-wise multiplication */
static PyObject *matrix_mul_float(MatrixObject *mat, const float scalar)
{
float tmat[MATRIX_MAX_DIM * MATRIX_MAX_DIM];
@@ -2332,6 +2332,114 @@ static PyObject *matrix_mul_float(MatrixObject *mat, const float scalar)
static PyObject *Matrix_mul(PyObject *m1, PyObject *m2)
{
float scalar;
MatrixObject *mat1 = NULL, *mat2 = NULL;
if (MatrixObject_Check(m1)) {
mat1 = (MatrixObject *)m1;
if (BaseMath_ReadCallback(mat1) == -1)
return NULL;
}
if (MatrixObject_Check(m2)) {
mat2 = (MatrixObject *)m2;
if (BaseMath_ReadCallback(mat2) == -1)
return NULL;
}
if (mat1 && mat2) {
#ifdef USE_MATHUTILS_ELEM_MUL
/* MATRIX * MATRIX */
float mat[MATRIX_MAX_DIM * MATRIX_MAX_DIM];
if ((mat1->num_row != mat2->num_row) || (mat1->num_col != mat2->num_col)) {
PyErr_SetString(PyExc_ValueError,
"matrix1 * matrix2: matrix1 number of rows/columns "
"and the matrix2 number of rows/columns must be the same");
return NULL;
}
mul_vn_vnvn(mat, mat1->matrix, mat2->matrix, mat1->num_col * mat1->num_row);
return Matrix_CreatePyObject(mat, mat2->num_col, mat1->num_row, Py_TYPE(mat1));
#endif
}
else if (mat2) {
/*FLOAT/INT * MATRIX */
if (((scalar = PyFloat_AsDouble(m1)) == -1.0f && PyErr_Occurred()) == 0) {
return matrix_mul_float(mat2, scalar);
}
}
else if (mat1) {
/* MATRIX * FLOAT/INT */
if (((scalar = PyFloat_AsDouble(m2)) == -1.0f && PyErr_Occurred()) == 0) {
return matrix_mul_float(mat1, scalar);
}
}
PyErr_Format(PyExc_TypeError,
"Element-wise multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(m1)->tp_name, Py_TYPE(m2)->tp_name);
return NULL;
}
/*------------------------obj *= obj------------------------------
* Inplace element-wise multiplication */
static PyObject *Matrix_imul(PyObject *m1, PyObject *m2)
{
float scalar;
MatrixObject *mat1 = NULL, *mat2 = NULL;
if (MatrixObject_Check(m1)) {
mat1 = (MatrixObject *)m1;
if (BaseMath_ReadCallback(mat1) == -1)
return NULL;
}
if (MatrixObject_Check(m2)) {
mat2 = (MatrixObject *)m2;
if (BaseMath_ReadCallback(mat2) == -1)
return NULL;
}
if (mat1 && mat2) {
#ifdef USE_MATHUTILS_ELEM_MUL
/* MATRIX *= MATRIX */
if ((mat1->num_row != mat2->num_row) || (mat1->num_col != mat2->num_col)) {
PyErr_SetString(PyExc_ValueError,
"matrix1 *= matrix2: matrix1 number of rows/columns "
"and the matrix2 number of rows/columns must be the same");
return NULL;
}
mul_vn_vn(mat1->matrix, mat2->matrix, mat1->num_col * mat1->num_row);
#else
PyErr_Format(PyExc_TypeError,
"Inplace element-wise multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(m1)->tp_name, Py_TYPE(m2)->tp_name);
return NULL;
#endif
}
else if (mat1 && (((scalar = PyFloat_AsDouble(m2)) == -1.0f && PyErr_Occurred()) == 0)) {
/* MATRIX *= FLOAT/INT */
mul_vn_fl(mat1->matrix, mat1->num_row * mat1->num_col, scalar);
}
else {
PyErr_Format(PyExc_TypeError,
"Inplace element-wise multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(m1)->tp_name, Py_TYPE(m2)->tp_name);
return NULL;
}
(void)BaseMath_WriteCallback(mat1);
Py_INCREF(m1);
return m1;
}
/*------------------------obj @ obj------------------------------
* matrix multiplication */
static PyObject *Matrix_matmul(PyObject *m1, PyObject *m2)
{
int vec_size;
MatrixObject *mat1 = NULL, *mat2 = NULL;
@@ -2348,15 +2456,15 @@ static PyObject *Matrix_mul(PyObject *m1, PyObject *m2)
}
if (mat1 && mat2) {
/* MATRIX * MATRIX */
/* MATRIX @ MATRIX */
float mat[MATRIX_MAX_DIM * MATRIX_MAX_DIM];
int col, row, item;
if (mat1->num_col != mat2->num_row) {
PyErr_SetString(PyExc_ValueError,
"matrix1 * matrix2: matrix1 number of columns "
"and the matrix2 number of rows must be the same");
"matrix1 * matrix2: matrix1 number of columns "
"and the matrix2 number of rows must be the same");
return NULL;
}
@@ -2372,14 +2480,8 @@ static PyObject *Matrix_mul(PyObject *m1, PyObject *m2)
return Matrix_CreatePyObject(mat, mat2->num_col, mat1->num_row, Py_TYPE(mat1));
}
else if (mat2) {
/*FLOAT/INT * MATRIX */
if (((scalar = PyFloat_AsDouble(m1)) == -1.0f && PyErr_Occurred()) == 0) {
return matrix_mul_float(mat2, scalar);
}
}
else if (mat1) {
/* MATRIX * VECTOR */
/* MATRIX @ VECTOR */
if (VectorObject_Check(m2)) {
VectorObject *vec2 = (VectorObject *)m2;
float tvec[MATRIX_MAX_DIM];
@@ -2398,21 +2500,70 @@ static PyObject *Matrix_mul(PyObject *m1, PyObject *m2)
return Vector_CreatePyObject(tvec, vec_size, Py_TYPE(m2));
}
/*FLOAT/INT * MATRIX */
else if (((scalar = PyFloat_AsDouble(m2)) == -1.0f && PyErr_Occurred()) == 0) {
return matrix_mul_float(mat1, scalar);
}
}
else {
BLI_assert(!"internal error");
}
PyErr_Format(PyExc_TypeError,
"Matrix multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(m1)->tp_name, Py_TYPE(m2)->tp_name);
"Matrix multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(m1)->tp_name, Py_TYPE(m2)->tp_name);
return NULL;
}
/*------------------------obj @= obj------------------------------
* inplace matrix multiplication */
static PyObject *Matrix_imatmul(PyObject *m1, PyObject *m2)
{
MatrixObject *mat1 = NULL, *mat2 = NULL;
if (MatrixObject_Check(m1)) {
mat1 = (MatrixObject *)m1;
if (BaseMath_ReadCallback(mat1) == -1)
return NULL;
}
if (MatrixObject_Check(m2)) {
mat2 = (MatrixObject *)m2;
if (BaseMath_ReadCallback(mat2) == -1)
return NULL;
}
if (mat1 && mat2) {
/* MATRIX @= MATRIX */
float mat[MATRIX_MAX_DIM * MATRIX_MAX_DIM];
int col, row, item;
if (mat1->num_col != mat2->num_row) {
PyErr_SetString(PyExc_ValueError,
"matrix1 * matrix2: matrix1 number of columns "
"and the matrix2 number of rows must be the same");
return NULL;
}
for (col = 0; col < mat2->num_col; col++) {
for (row = 0; row < mat1->num_row; row++) {
double dot = 0.0f;
for (item = 0; item < mat1->num_col; item++) {
dot += (double)(MATRIX_ITEM(mat1, row, item) * MATRIX_ITEM(mat2, item, col));
}
/* store in new matrix as overwriting original at this point will cause
* subsequent iterations to use incorrect values */
mat[(col * mat1->num_row) + row] = (float)dot;
}
}
/* copy matrix back */
memcpy(mat1->matrix, mat, mat1->num_row * mat1->num_col);
}
else {
PyErr_Format(PyExc_TypeError,
"Inplace matrix multiplication: "
"not supported between '%.200s' and '%.200s' types",
Py_TYPE(m1)->tp_name, Py_TYPE(m2)->tp_name);
return NULL;
}
(void)BaseMath_WriteCallback(mat1);
Py_INCREF(m1);
return m1;
}
/*-----------------PROTOCOL DECLARATIONS--------------------------*/
static PySequenceMethods Matrix_SeqMethods = {
@@ -2527,7 +2678,7 @@ static PyNumberMethods Matrix_NumMethods = {
NULL, /*nb_float*/
NULL, /* nb_inplace_add */
NULL, /* nb_inplace_subtract */
NULL, /* nb_inplace_multiply */
(binaryfunc) Matrix_imul, /* nb_inplace_multiply */
NULL, /* nb_inplace_remainder */
NULL, /* nb_inplace_power */
NULL, /* nb_inplace_lshift */
@@ -2540,6 +2691,8 @@ static PyNumberMethods Matrix_NumMethods = {
NULL, /* nb_inplace_floor_divide */
NULL, /* nb_inplace_true_divide */
NULL, /* nb_index */
(binaryfunc) Matrix_matmul, /* nb_matrix_multiply */
(binaryfunc) Matrix_imatmul, /* nb_inplace_matrix_multiply */
};
PyDoc_STRVAR(Matrix_translation_doc,