作業中のメモ

よく「計算機」を使って作業をする.知らなかったことを中心にまとめるつもり.

PythonからCの関数を呼び出す(Cの関数の共有ライブラリ化) Part2

どうも,筆者です.

前回のつづきになる.

workspacememory.hatenablog.com

今回は,二つの配列の和,差,積を計算し,与えられた配列に結果を格納する関数をPythonから呼び出せるようにする.

Pythonで受け付けるオブジェクト型はlist型のみとする.

以下のサイトを参考にした.

docs.python.org

docs.python.org

docs.python.org

Cの関数定義

前回のコードに追加する形で実装していく.

ヘッダー

ヘッダーファイルの内容を以下に示す.

// custom.h
// Public interface of the small C library wrapped by custommodule.
#ifndef CUSTOM_H
#define CUSTOM_H
#include <stdint.h>

// Returns x + y plus the offset most recently stored with set_value() (initially 0).
extern int32_t add(int32_t x, int32_t y);
// Stores the offset that subsequent add() calls include in their result.
extern void set_value(int32_t val);
// Approximates a root of target_func with Newton's method, starting at x0 and
// running at most max_iter iterations; the derivative is estimated by a
// central difference.
extern double newton_method(double x0, int32_t max_iter, double (*target_func)(double x));
// added in this part
// Computes out[i] = left_term[i] <op> right_term[i] for 0 <= i < len, where op
// is '+', '-' or '*'. Returns 0 on success, 1 if any pointer is NULL or op is
// not supported (out is left untouched in that case).
extern int32_t calc_array(int32_t len, char op, const double *left_term, const double *right_term, double *out);
// end of addition

#endif // CUSTOM_H

関数

#include <stdint.h>
// 追加部分1
#include <stddef.h>
// ここまで
#include "custom.h"
// Step size h of the central-difference derivative estimate in newton_method()
#define DELTA (1e-3)

// Offset added by add() to every sum; starts at 0 and is changed by set_value().
static int32_t g_val = 0;

// Returns x + y + g_val.
// The sum is computed in uint32_t so that results outside the int32_t range
// wrap around instead of triggering signed-overflow undefined behaviour
// (reachable from Python, e.g. add(2147483647, 1)).
// NOTE: converting an out-of-range uint32_t back to int32_t is
// implementation-defined; GCC and Clang document modulo-2^32 reduction.
int32_t add(int32_t x, int32_t y) {
    uint32_t z = (uint32_t)x + (uint32_t)y + (uint32_t)g_val;

    return (int32_t)z;
}

// Sets the offset that subsequent add() calls include in their result.
void set_value(int32_t val) {
    g_val = val;
}

// Approximates a root of target_func with Newton's method.
//
// Starting from x0, performs up to max_iter updates
//     x <- x - f(x) / f'(x),
// where f'(x) is estimated by the central difference
//     (f(x + DELTA) - f(x - DELTA)) / (2 * DELTA).
//
// Returns x0 unchanged when target_func is NULL or max_iter <= 0.
// If the derivative estimate is exactly zero, iteration stops and the current
// estimate is returned instead of dividing by zero (which would yield inf/NaN).
double newton_method(double x0, int32_t max_iter, double (*target_func)(double x)) {
    int32_t iter;
    double old_x, new_x, df;
    new_x = x0;

    if (NULL == target_func) {
        return new_x;
    }

    for (iter = 0; iter < max_iter; iter++) {
        old_x = new_x;
        df = ((*target_func)(old_x + (double)DELTA) - (*target_func)(old_x - (double)DELTA)) / (double)DELTA * 0.5;
        // flat region: the Newton step is undefined, keep the last estimate
        if (0.0 == df) {
            break;
        }
        new_x = old_x - (*target_func)(old_x) / df;
    }

    return new_x;
}

// 追加部分2
// Element-wise arithmetic on two arrays of length len.
//
// op selects the operation: '+', '-' or '*'; out[i] receives
// left_term[i] <op> right_term[i]. out may alias an input array.
// Returns 0 on success and 1 when any pointer is NULL or op is unsupported;
// on failure nothing is written to out. A non-positive len writes nothing.
int32_t calc_array(int32_t len, char op, const double *left_term, const double *right_term, double *out) {
    int32_t i;

    // missing buffers are an error regardless of len
    if (left_term == NULL || right_term == NULL || out == NULL) {
        return 1;
    }
    // validate the operator before touching out[]
    if (op != '+' && op != '-' && op != '*') {
        return 1;
    }

    for (i = 0; i < len; ++i) {
        const double a = left_term[i];
        const double b = right_term[i];

        if (op == '+') {
            out[i] = a + b;
        } else if (op == '-') {
            out[i] = a - b;
        } else {
            out[i] = a * b;
        }
    }

    return 0;
}
// ここまで

Wrapper関数定義

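Pythonで呼び出す際は,演算子と二つのリストを与え,結果をリスト形式で返す方式とする. なお,演算子はPyArg_ParseTupleの"c"フォーマットで受け取るため,長さ1のbytes型(例: b'+')で渡す必要がある. また,引数やメモリ確保のチェックを重ねるとネストが深くなるため,エラー時はgotoで後処理(メモリ解放)へ抜ける形で例外処理を実装している.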

// Python.h must come before any standard header (see "Extending Python with C").
#define PY_SSIZE_T_CLEAN
#include "Python.h"
#include <stdint.h>
#include <stdlib.h>
#include "custom.h"

// Python: add(x, y) -> int
// Parses two ints and returns add(x, y) as a Python int.
// Returns NULL (exception already set by PyArg_ParseTuple) on bad arguments.
PyObject *wrapper_add(PyObject *self, PyObject *args) {
    int32_t lhs, rhs;

    if (!PyArg_ParseTuple(args, "ii", &lhs, &rhs)) {
        return NULL;
    }

    return Py_BuildValue("i", add(lhs, rhs));
}

// Python: set_value(val) -> None
// Parses one int and forwards it to set_value().
// Returns NULL (exception already set by PyArg_ParseTuple) on bad arguments.
PyObject *wrapper_set_value(PyObject *self, PyObject *args) {
    int32_t new_value;

    if (!PyArg_ParseTuple(args, "i", &new_value)) {
        return NULL;
    }
    set_value(new_value);

    // an empty format string makes Py_BuildValue return None
    return Py_BuildValue("");
}

// Python callable currently used by alternative_obj_func(); set (with a strong
// reference) by wrapper_newton_method().
static PyObject *py_object_function = NULL;

// C-compatible trampoline passed to newton_method(): calls
// py_object_function(x) and returns the result as a double.
// Falls back to returning x when no callable is set, when a Python exception
// is already pending (a previous call raised), when the call itself fails, or
// when the callable does not return a float. A failing call leaves its
// exception set for the caller to report.
static double alternative_obj_func(double x) {
    double ret_val = x;
    PyObject *arg;
    PyObject *result;

    // never call back into Python while an exception is pending
    if ((NULL == py_object_function) || (NULL != PyErr_Occurred())) {
        return ret_val;
    }

    // set argument
    arg = Py_BuildValue("(d)", x);
    if (NULL == arg) {
        return ret_val;
    }
    // call function (PyEval_CallObject was removed in Python 3.13)
    result = PyObject_CallObject(py_object_function, arg);
    Py_DECREF(arg);

    // check result
    if (NULL != result) {
        if (PyFloat_Check(result)) {
            ret_val = PyFloat_AsDouble(result);
        }
        Py_DECREF(result);
    }

    return ret_val;
}

// Python: newton(x0, max_iter, func) -> float
// Runs newton_method() on the Python callable func (via alternative_obj_func).
// The callable is kept in py_object_function until the next call replaces it.
// Raises TypeError when func is not callable; if func raised during the
// iteration, that exception is propagated instead of returning a value.
PyObject *wrapper_newton_method(PyObject *self, PyObject *args) {
    double x0;
    double est_x;
    int32_t max_iter;
    PyObject *target_func = NULL;

    if (!PyArg_ParseTuple(args, "diO", &x0, &max_iter, &target_func)) {
        return NULL;
    }
    if (!PyCallable_Check(target_func)) {
        PyErr_SetString(PyExc_TypeError, "Need a callable object!");
        return NULL;
    }

    // set function pointer: take a new reference before releasing the old one
    Py_INCREF(target_func);
    Py_XDECREF(py_object_function);
    py_object_function = target_func;

    // call function
    est_x = newton_method(x0, max_iter, alternative_obj_func);
    // returning a value with an exception set would raise SystemError
    if (NULL != PyErr_Occurred()) {
        return NULL;
    }

    return Py_BuildValue("d", est_x);
}

// 追加部分1
PyObject *wrapper_calc_array(PyObject *self, PyObject *args) {
    int32_t idx;
    int32_t return_code;
    int size;
    char operator;
    double *work_left, *work_right, *work_out;
    PyObject *left_term, *right_term;
    PyObject *ret;
    work_left = NULL;
    work_right = NULL;
    work_out = NULL;
    ret = NULL;

    // check argument
    if (!PyArg_ParseTuple(args, "cOO", &operator, &left_term, &right_term)) {
        PyErr_SetString(PyExc_TypeError, "Invalid argument");
        goto EXIT_WRAPPER_CALC_ARRAY;
    }
    // check object type
    if (!PyList_Check(left_term) || !PyList_Check(right_term)) {
        PyErr_SetString(PyExc_TypeError, "Need to list object");
        goto EXIT_WRAPPER_CALC_ARRAY;
    }
    // get list length
    size = (int32_t)PyList_Size(left_term);
    // check list length
    if (size != (int32_t)PyList_Size(right_term)) {
        PyErr_SetString(PyExc_RuntimeError, "Size of list does not mismatch");
        goto EXIT_WRAPPER_CALC_ARRAY;
    }
    // malloc memory
    work_left = (double *)malloc(sizeof(double) * size);
    work_right = (double *)malloc(sizeof(double) * size);
    work_out = (double *)malloc(sizeof(double) * size);
    ret = PyList_New(size);
    // check memory
    if ((NULL == work_left) || (NULL == work_right) || (NULL == work_out) || (NULL == ret)) {
        PyErr_SetString(PyExc_MemoryError, "lack of memory");
        Py_XDECREF(ret); // consider ret is NULL
        ret = NULL;
        goto EXIT_WRAPPER_CALC_ARRAY;
    }
    // copy data
    for (idx = 0; idx < size; idx++) {
        // リストから取り出したPyObject型をdouble型に変換し,配列に格納
        work_left[idx] = PyFloat_AsDouble(PyList_GetItem(left_term, idx));
        work_right[idx] = PyFloat_AsDouble(PyList_GetItem(right_term, idx));
    }
    // calculate
    return_code = calc_array(size, operator, (const double *)work_left, (const double *)work_right, work_out);
    // check return code
    if (return_code == 0) {
        // copy result
        for (idx = 0; idx < size; idx++) {
            // double型をPyObject型に変換し,リストの該当箇所に格納
            PyList_SetItem(ret, idx, PyFloat_FromDouble(work_out[idx]));
        }
    }
    else {
        PyErr_SetString(PyExc_RuntimeError, "Calculation failed");
        Py_DECREF(ret); // 失敗した場合は,確保したリストを削除
        ret = NULL;
        goto EXIT_WRAPPER_CALC_ARRAY;
    }

EXIT_WRAPPER_CALC_ARRAY:
    if (NULL != work_left) {
        free(work_left);
    }
    if (NULL != work_right) {
        free(work_right);
    }
    if (NULL != work_out) {
        free(work_out);
    }

    return ret;
}
// ここまで

static PyMethodDef custom_methods[] = {
    {"add", wrapper_add, METH_VARARGS, NULL},
    {"set_value", wrapper_set_value, METH_VARARGS, NULL},
    {"newton", wrapper_newton_method, METH_VARARGS, NULL},
    // 追加部分2
    {"calc_array", wrapper_calc_array, METH_VARARGS, NULL},
    // ここまで
    {NULL, NULL, 0, NULL}
};

static struct PyModuleDef custommodule = {
    PyModuleDef_HEAD_INIT,
    "custommodule",
    "",
    -1,
    custom_methods,
};

PyMODINIT_FUNC PyInit_custommodule(void) {
    return PyModule_Create(&custommodule);
}

コンパイル

前回のShell Scriptと同一のため,割愛する.

動作確認

こちらも,前回のコードに追加する形で実装した.

import custommodule as c_mod

# add() sums its arguments plus a static variable kept on the C side;
# that variable starts at 0.
c = c_mod.add(2, 3)
print('c_mod.add(2, 3) = ', c)  # 2 + 3 + 0 -> 5

# Store 1 in the C-side static variable, then call add() again.
c_mod.set_value(1)
print('c_mod.set_value(1)')
c = c_mod.add(2, 3)
print('c_mod.add(2, 3) = ', c)  # 2 + 3 + 1 -> 6


# newton() receives a Python callable and reaches it through a C function pointer.
def f(x):
    return x * x - 2.0


est_x = c_mod.newton(2.0, 10, f)
print('c_mod.newton(2.0, 10, f) = ', est_x)  # sqrt(2.0) -> 1.4142135...

# --- added in Part 2 ---
xs = [0.5, 2.1, 3.2]
ys = [0.5, 1.1, -3.2]
zs = c_mod.calc_array(b'+', xs, ys)


def convert_f(values):
    """Join the items of *values* with ', ' (used to echo the input lists)."""
    return ', '.join(map(str, values))


print("c_mod.calc_array(b'+', [{}], [{}]) = ".format(convert_f(xs), convert_f(ys)), zs)

# ==================
# exception examples
# ==================
# Each entry is the argument tuple of one failing call; the trailing comment
# is the message the C side raises for it.
failing_calls = [
    ('+', xs, ys),           # Invalid argument (str instead of bytes)
    (b'+', tuple(xs), ys),   # Need to list object
    (b'+', xs, ys + [1.1]),  # Size of list does not mismatch
    (b'/', xs, ys),          # Calculation failed
]
for call_args in failing_calls:
    try:
        c_mod.calc_array(*call_args)
    except Exception as e:
        print(e)
# --- end of Part 2 additions ---
c_mod.add(2, 3) =  5
c_mod.set_value(1)
c_mod.add(2, 3) =  6
c_mod.newton(2.0, 10, f) =  1.414213562373095
# 追加部分
c_mod.calc_array(b'+', [0.5, 2.1, 3.2], [0.5, 1.1, -3.2]) =  [1.0, 3.2, 0.0]
Invalid argument
Need to list object
Size of list does not mismatch
Calculation failed
# ここまで

期待通りの結果が得られている.

気になる点

動作はしたが,以下の2点が気になる.

  • 演算時,一時的にメモリ使用量が2倍以上になる.この対策はないのか.
  • Py_DECREFを利用して確保した領域を解放しているつもりだが,期待通りの動作をしているか.