summaryrefslogtreecommitdiffstats
path: root/src/arrow/python/pyarrow/tests/test_cython.py
blob: e202b417a188561635b8749ca5f9ecef141a6ae5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import shutil
import subprocess
import sys

import pytest

import pyarrow as pa
import pyarrow.tests.util as test_util


# Directory holding this test module; the example .pyx sources are copied
# from here into each test's temporary workspace.
here = os.path.dirname(os.path.abspath(__file__))

# Optional extra library directory for the generated extensions, taken from
# the PYARROW_TEST_LD_PATH environment variable (empty string when unset).
test_ld_path = os.getenv('PYARROW_TEST_LD_PATH', '')

# Extra compiler flags passed to the extension build: the C++ standard is
# only requested explicitly on POSIX platforms.
compiler_opts = ['-std=c++11'] if os.name == 'posix' else []


# Source of the throw-away ``setup.py`` written by the tests below.
#
# The template is filled in with ``str.format`` using three fields:
# ``pyx_file`` (the Cython source to build), ``compiler_opts`` (extra
# compile arguments) and ``test_ld_path`` (an optional additional library
# directory).  Each field is inserted with ``!r`` so it appears in the
# generated script as a Python literal.
#
# The generated script cythonizes the .pyx file and, for every resulting
# extension, adds the NumPy and pyarrow include directories, the pyarrow
# libraries and library directories, the optional custom library directory
# and the extra compile arguments before calling ``setup()``.
#
# The leading ``if 1:`` allows the script body to stay indented inside this
# triple-quoted string while still being valid Python when written to disk.
setup_template = """if 1:
    from setuptools import setup
    from Cython.Build import cythonize

    import numpy as np

    import pyarrow as pa

    ext_modules = cythonize({pyx_file!r})
    compiler_opts = {compiler_opts!r}
    custom_ld_path = {test_ld_path!r}

    for ext in ext_modules:
        # XXX required for numpy/numpyconfig.h,
        # included from arrow/python/api.h
        ext.include_dirs.append(np.get_include())
        ext.include_dirs.append(pa.get_include())
        ext.libraries.extend(pa.get_libraries())
        ext.library_dirs.extend(pa.get_library_dirs())
        if custom_ld_path:
            ext.library_dirs.append(custom_ld_path)
        ext.extra_compile_args.extend(compiler_opts)
        print("Extension module:",
              ext, ext.include_dirs, ext.libraries, ext.library_dirs)

    setup(
        ext_modules=ext_modules,
    )
"""


def check_cython_example_module(mod):
    """
    Exercise the functions exposed by the compiled example module.

    Parameters
    ----------
    mod : module
        The imported extension module; it must provide
        ``get_array_length`` and ``cast_scalar``.

    Raises
    ------
    AssertionError
        If any of the checked behaviours differs from what is expected.
    """
    # get_array_length: accepts a pyarrow array, rejects anything else
    three_ints = pa.array([1, 2, 3])
    assert mod.get_array_length(three_ints) == 3
    with pytest.raises(TypeError, match="not an array"):
        mod.get_array_length(None)

    # cast_scalar: a supported cast returns the converted scalar ...
    int_scalar = pa.scalar(123)
    assert mod.cast_scalar(int_scalar, pa.utf8()) == pa.scalar("123")

    # ... while an unsupported target type is reported as not implemented
    list_type = pa.list_(pa.int64())
    expected_msg = "casting scalars of type int64 to type list"
    with pytest.raises(NotImplementedError, match=expected_msg):
        mod.cast_scalar(int_scalar, list_type)


@pytest.mark.cython
def test_cython_api(tmpdir):
    """
    Basic test for the Cython API.

    The example ``pyarrow_cython_example.pyx`` is copied into ``tmpdir``,
    built in place with a generated ``setup.py`` and then checked twice:
    once imported into this process, and once from a fresh subprocess that
    has not imported pyarrow beforehand.

    Parameters
    ----------
    tmpdir : py.path.local
        pytest fixture providing the temporary build directory.

    Raises
    ------
    ImportError
        If Cython is not installed.
    subprocess.CalledProcessError
        If the build or the subprocess check exits with a non-zero status.
    """
    # Fail early if cython is not found
    import cython  # noqa

    with tmpdir.as_cwd():
        # Set up temporary workspace
        pyx_file = 'pyarrow_cython_example.pyx'
        shutil.copyfile(os.path.join(here, pyx_file),
                        os.path.join(str(tmpdir), pyx_file))
        # Create setup.py file
        setup_code = setup_template.format(pyx_file=pyx_file,
                                           compiler_opts=compiler_opts,
                                           test_ld_path=test_ld_path)
        with open('setup.py', 'w') as f:
            f.write(setup_code)

        # ARROW-2263: Make environment with this pyarrow/ package first on the
        # PYTHONPATH, for local dev environments
        subprocess_env = test_util.get_modified_env_with_pythonpath()

        # Compile extension module
        subprocess.check_call([sys.executable, 'setup.py',
                               'build_ext', '--inplace'],
                              env=subprocess_env)

        # Check basic functionality
        orig_path = sys.path[:]
        sys.path.insert(0, str(tmpdir))
        try:
            mod = __import__('pyarrow_cython_example')
            check_cython_example_module(mod)
        finally:
            sys.path = orig_path

        # Check the extension module is loadable from a subprocess without
        # pyarrow imported first.
        code = """if 1:
            import sys

            mod = __import__({mod_name!r})
            arr = mod.make_null_array(5)
            assert mod.get_array_length(arr) == 5
            assert arr.null_count == 5
        """.format(mod_name='pyarrow_cython_example')

        if sys.platform == 'win32':
            delim, var = ';', 'PATH'
        else:
            delim, var = ':', 'LD_LIBRARY_PATH'

        # Put pyarrow's library directories in front of any existing value.
        # An unset/empty variable is skipped so that the result does not end
        # with a delimiter, i.e. an unintended empty search path entry.
        search_dirs = list(pa.get_library_dirs())
        inherited = subprocess_env.get(var, '')
        if inherited:
            search_dirs.append(inherited)
        subprocess_env[var] = delim.join(search_dirs)

        # The child's output is not needed.  DEVNULL is used rather than
        # PIPE because check_call() never reads from the pipe, so a child
        # that fills the pipe buffer would block forever.
        subprocess.check_call([sys.executable, '-c', code],
                              stdout=subprocess.DEVNULL,
                              env=subprocess_env)


@pytest.mark.cython
def test_visit_strings(tmpdir):
    """
    Check that a compiled Cython function can call back into Python.

    ``bound_function_visit_strings.pyx`` is copied into ``tmpdir`` and built
    in place; its ``_visit_strings`` function is then expected to invoke the
    given callback once per string, in order, and to let an exception raised
    by the callback propagate to the caller.

    Parameters
    ----------
    tmpdir : py.path.local
        pytest fixture providing the temporary build directory.

    Raises
    ------
    ImportError
        If Cython is not installed.
    subprocess.CalledProcessError
        If building the extension exits with a non-zero status.
    """
    # Fail early if cython is not found
    import cython  # noqa

    with tmpdir.as_cwd():
        # Set up temporary workspace
        pyx_file = 'bound_function_visit_strings.pyx'
        shutil.copyfile(os.path.join(here, pyx_file),
                        os.path.join(str(tmpdir), pyx_file))
        # Create setup.py file
        setup_code = setup_template.format(pyx_file=pyx_file,
                                           compiler_opts=compiler_opts,
                                           test_ld_path=test_ld_path)
        with open('setup.py', 'w') as f:
            f.write(setup_code)

        subprocess_env = test_util.get_modified_env_with_pythonpath()

        # Compile extension module
        subprocess.check_call([sys.executable, 'setup.py',
                               'build_ext', '--inplace'],
                              env=subprocess_env)

    # Make the freshly built module importable, and always restore sys.path
    # afterwards so the temporary directory does not leak into other tests.
    orig_path = sys.path[:]
    sys.path.insert(0, str(tmpdir))
    try:
        mod = __import__('bound_function_visit_strings')

        strings = ['a', 'b', 'c']
        visited = []
        mod._visit_strings(strings, visited.append)

        assert visited == strings

        def raise_on_b(s):
            if s == 'b':
                raise ValueError('wtf')

        # Only the call under test is guarded, so an unrelated ValueError
        # cannot satisfy the expectation by accident.
        with pytest.raises(ValueError, match="wtf"):
            mod._visit_strings(strings, raise_on_b)
    finally:
        sys.path = orig_path