diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:54:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:54:28 +0000 |
commit | e6918187568dbd01842d8d1d2c808ce16a894239 (patch) | |
tree | 64f88b554b444a49f656b6c656111a145cbbaa28 /src/arrow/python/pyarrow/tests/strategies.py | |
parent | Initial commit. (diff) | |
download | ceph-e6918187568dbd01842d8d1d2c808ce16a894239.tar.xz ceph-e6918187568dbd01842d8d1d2c808ce16a894239.zip |
Adding upstream version 18.2.2.upstream/18.2.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/arrow/python/pyarrow/tests/strategies.py')
-rw-r--r-- | src/arrow/python/pyarrow/tests/strategies.py | 419 |
1 files changed, 419 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/tests/strategies.py b/src/arrow/python/pyarrow/tests/strategies.py new file mode 100644 index 000000000..d314785ff --- /dev/null +++ b/src/arrow/python/pyarrow/tests/strategies.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import datetime

import pytz
import hypothesis as h
import hypothesis.strategies as st
import hypothesis.extra.numpy as npst
import hypothesis.extra.pytz as tzst
import numpy as np

import pyarrow as pa


# TODO(kszucs): alphanum_text, surrogate_text
# Text limited to the code points 0x41..0x7E; used below for field names.
custom_text = st.text(
    alphabet=st.characters(
        min_codepoint=0x41,
        max_codepoint=0x7E
    )
)

null_type = st.just(pa.null())
bool_type = st.just(pa.bool_())

binary_type = st.just(pa.binary())
string_type = st.just(pa.string())
large_binary_type = st.just(pa.large_binary())
large_string_type = st.just(pa.large_string())
fixed_size_binary_type = st.builds(
    pa.binary,
    st.integers(min_value=0, max_value=16)
)
binary_like_types = st.one_of(
    binary_type,
    string_type,
    large_binary_type,
    large_string_type,
    fixed_size_binary_type
)

signed_integer_types = st.sampled_from([
    pa.int8(),
    pa.int16(),
    pa.int32(),
    pa.int64()
])
unsigned_integer_types = st.sampled_from([
    pa.uint8(),
    pa.uint16(),
    pa.uint32(),
    pa.uint64()
])
integer_types = st.one_of(signed_integer_types, unsigned_integer_types)

floating_types = st.sampled_from([
    pa.float16(),
    pa.float32(),
    pa.float64()
])
decimal128_type = st.builds(
    pa.decimal128,
    precision=st.integers(min_value=1, max_value=38),
    scale=st.integers(min_value=1, max_value=38)
)
decimal256_type = st.builds(
    pa.decimal256,
    precision=st.integers(min_value=1, max_value=76),
    scale=st.integers(min_value=1, max_value=76)
)
numeric_types = st.one_of(integer_types, floating_types,
                          decimal128_type, decimal256_type)

date_types = st.sampled_from([
    pa.date32(),
    pa.date64()
])
time_types = st.sampled_from([
    pa.time32('s'),
    pa.time32('ms'),
    pa.time64('us'),
    pa.time64('ns')
])
timestamp_types = st.builds(
    pa.timestamp,
    unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
    tz=tzst.timezones()
)
duration_types = st.builds(
    pa.duration,
    st.sampled_from(['s', 'ms', 'us', 'ns'])
)
# st.sampled_from() requires an ordered collection of elements, so a single
# DataType instance must be wrapped with st.just() instead (the same pattern
# as null_type and bool_type above).
interval_types = st.just(pa.month_day_nano_interval())
temporal_types = st.one_of(
    date_types,
    time_types,
    timestamp_types,
    duration_types,
    interval_types
)

primitive_types = st.one_of(
    null_type,
    bool_type,
    numeric_types,
    temporal_types,
    binary_like_types
)

metadata = st.dictionaries(st.text(), st.text())


@st.composite
def fields(draw, type_strategy=primitive_types):
    """
    Draw a pyarrow Field.

    The name is drawn from ``custom_text``, the type from `type_strategy`
    and the metadata from ``metadata``. A field of null type is always
    nullable; for any other type the nullability is drawn at random.
    """
    name = draw(custom_text)
    typ = draw(type_strategy)
    if pa.types.is_null(typ):
        nullable = True
    else:
        nullable = draw(st.booleans())
    meta = draw(metadata)
    return pa.field(name, type=typ, nullable=nullable, metadata=meta)


def list_types(item_strategy=primitive_types):
    """
    Return a strategy for list types whose item type is drawn from
    `item_strategy`.

    Combines ``pa.list_(item)``, ``pa.large_list(item)`` and
    ``pa.list_(item, n)`` with ``n`` drawn from 0..16.
    """
    return (
        st.builds(pa.list_, item_strategy) |
        st.builds(pa.large_list, item_strategy) |
        st.builds(
            pa.list_,
            item_strategy,
            st.integers(min_value=0, max_value=16)
        )
    )


@st.composite
def struct_types(draw, item_strategy=primitive_types):
    """
    Draw a struct type whose fields have types drawn from `item_strategy`.

    Examples containing duplicated field names are discarded.
    """
    fields_strategy = st.lists(fields(item_strategy))
    fields_rendered = draw(fields_strategy)
    field_names = [field.name for field in fields_rendered]
    # check that field names are unique, see ARROW-9997
    h.assume(len(set(field_names)) == len(field_names))
    return pa.struct(fields_rendered)


def dictionary_types(key_strategy=None, value_strategy=None):
    """
    Return a strategy for dictionary types.

    Parameters
    ----------
    key_strategy : SearchStrategy, optional
        Strategy for the index type; ``signed_integer_types`` when omitted.
    value_strategy : SearchStrategy, optional
        Strategy for the value type; when omitted, one of boolean, integer,
        float32/float64, binary, string or fixed size binary.
    """
    key_strategy = key_strategy or signed_integer_types
    value_strategy = value_strategy or st.one_of(
        bool_type,
        integer_types,
        st.sampled_from([pa.float32(), pa.float64()]),
        binary_type,
        string_type,
        fixed_size_binary_type,
    )
    return st.builds(pa.dictionary, key_strategy, value_strategy)


@st.composite
def map_types(draw, key_strategy=primitive_types,
              item_strategy=primitive_types):
    """
    Draw a map type; examples whose key type is null are discarded.
    """
    key_type = draw(key_strategy)
    h.assume(not pa.types.is_null(key_type))
    value_type = draw(item_strategy)
    return pa.map_(key_type, value_type)


# union type
# extension type


def schemas(type_strategy=primitive_types, max_fields=None):
    """
    Return a strategy for schemas with at most `max_fields` fields whose
    types are drawn from `type_strategy`.
    """
    children = st.lists(fields(type_strategy), max_size=max_fields)
    return st.builds(pa.schema, children)


# Deferred because the definition refers to itself in order to produce
# nested list and struct types.
all_types = st.deferred(
    lambda: (
        primitive_types |
        list_types() |
        struct_types() |
        dictionary_types() |
        map_types() |
        list_types(all_types) |
        struct_types(all_types)
    )
)
all_fields = fields(all_types)
all_schemas = schemas(all_types)


_default_array_sizes = st.integers(min_value=0, max_value=20)


@st.composite
def _pylist(draw, value_type, size, nullable=True):
    """
    Draw an array of `value_type` with `size` elements and return it as a
    Python list.

    NOTE: `nullable` is accepted for signature symmetry with ``arrays`` but
    is not used; the array is always drawn with ``nullable=False``.
    """
    arr = draw(arrays(value_type, size=size, nullable=False))
    return arr.to_pylist()


@st.composite
def _pymap(draw, key_type, value_type, size, nullable=True):
    """
    Draw a list of ``(key, value)`` tuples.

    `size` must be a strategy: the number of pairs is drawn from it. Keys
    are always requested as non-nullable; `nullable` is forwarded for the
    values only.
    """
    length = draw(size)
    keys = draw(_pylist(key_type, size=length, nullable=False))
    values = draw(_pylist(value_type, size=length, nullable=nullable))
    return list(zip(keys, values))


@st.composite
def arrays(draw, type, size=None, nullable=True):
    """
    Draw a pyarrow Array.

    Parameters
    ----------
    type : pa.DataType or SearchStrategy
        Type of the array, or a strategy to draw the type from.
    size : int or SearchStrategy, optional
        Length of the array, or a strategy to draw the length from. Drawn
        from ``_default_array_sizes`` when omitted.
    nullable : bool, default True
        Whether None may be drawn as an element. Only applied by the
        branches which build the array from a list of Python values; the
        integer, floating point, dictionary and struct branches return
        early without applying it.

    Raises
    ------
    TypeError
        If `type` is neither a DataType nor a strategy, or if `size` is
        not an integer, a strategy or None.
    NotImplementedError
        If there is no generation rule for the drawn type.
    """
    if isinstance(type, st.SearchStrategy):
        ty = draw(type)
    elif isinstance(type, pa.DataType):
        ty = type
    else:
        raise TypeError('Type must be a pyarrow DataType')

    if isinstance(size, st.SearchStrategy):
        size = draw(size)
    elif size is None:
        size = draw(_default_array_sizes)
    elif not isinstance(size, int):
        raise TypeError('Size must be an integer')

    if pa.types.is_null(ty):
        # a null array can only hold nulls, discard non-nullable requests
        h.assume(nullable)
        value = st.none()
    elif pa.types.is_boolean(ty):
        value = st.booleans()
    elif pa.types.is_integer(ty):
        values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
        return pa.array(values, type=ty)
    elif pa.types.is_floating(ty):
        values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
        # Workaround ARROW-4952: no easy way to assert array equality
        # in a NaN-tolerant way.
        values[np.isnan(values)] = -42.0
        return pa.array(values, type=ty)
    elif pa.types.is_decimal(ty):
        # TODO(kszucs): properly limit the precision
        # value = st.decimals(places=type.scale, allow_infinity=False)
        h.reject()
    elif pa.types.is_time(ty):
        value = st.times()
    elif pa.types.is_date(ty):
        value = st.dates()
    elif pa.types.is_timestamp(ty):
        # bound the datetimes by the int64 range interpreted as nanoseconds
        min_int64 = -(2**63)
        max_int64 = 2**63 - 1
        min_datetime = datetime.datetime.fromtimestamp(min_int64 // 10**9)
        max_datetime = datetime.datetime.fromtimestamp(max_int64 // 10**9)
        try:
            # the timezone is either a numeric offset in hours ...
            offset_hours = int(ty.tz)
            tz = pytz.FixedOffset(offset_hours * 60)
        except ValueError:
            # ... or a timezone name known to pytz
            tz = pytz.timezone(ty.tz)
        value = st.datetimes(timezones=st.just(tz), min_value=min_datetime,
                             max_value=max_datetime)
    elif pa.types.is_duration(ty):
        value = st.timedeltas()
    elif pa.types.is_binary(ty) or pa.types.is_large_binary(ty):
        value = st.binary()
    elif pa.types.is_string(ty) or pa.types.is_large_string(ty):
        value = st.text()
    elif pa.types.is_fixed_size_binary(ty):
        value = st.binary(min_size=ty.byte_width, max_size=ty.byte_width)
    elif pa.types.is_list(ty):
        # every list element holds `size` items as well
        value = _pylist(ty.value_type, size=size, nullable=nullable)
    elif pa.types.is_large_list(ty):
        value = _pylist(ty.value_type, size=size, nullable=nullable)
    elif pa.types.is_fixed_size_list(ty):
        value = _pylist(ty.value_type, size=ty.list_size, nullable=nullable)
    elif pa.types.is_dictionary(ty):
        values = _pylist(ty.value_type, size=size, nullable=nullable)
        return pa.array(draw(values), type=ty)
    elif pa.types.is_map(ty):
        value = _pymap(ty.key_type, ty.item_type, size=_default_array_sizes,
                       nullable=nullable)
    elif pa.types.is_struct(ty):
        h.assume(len(ty) > 0)
        # named struct_fields to avoid shadowing the module level `fields`
        # strategy
        struct_fields, child_arrays = [], []
        for field in ty:
            struct_fields.append(field)
            child_arrays.append(draw(arrays(field.type, size=size)))
        return pa.StructArray.from_arrays(child_arrays, fields=struct_fields)
    else:
        raise NotImplementedError(ty)

    if nullable:
        value = st.one_of(st.none(), value)
    values = st.lists(value, min_size=size, max_size=size)

    return pa.array(draw(values), type=ty)


@st.composite
def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
    """
    Draw a ChunkedArray made of `min_chunks` to `max_chunks` chunks of the
    given type, each chunk generated with ``arrays(type, size=chunk_size)``.

    Struct types are discarded.
    """
    if isinstance(type, st.SearchStrategy):
        type = draw(type)

    # TODO(kszucs): remove it, field metadata is not kept
    h.assume(not pa.types.is_struct(type))

    chunk = arrays(type, size=chunk_size)
    chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks)

    return pa.chunked_array(draw(chunks), type=type)


@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
    """
    Draw a RecordBatch.

    Parameters
    ----------
    type : SearchStrategy
        Strategy for the column types, forwarded to ``schemas``.
    rows : int or SearchStrategy, optional
        Number of rows, or a strategy to draw it from. Drawn from
        ``_default_array_sizes`` when omitted.
    max_fields : int, optional
        Upper bound for the number of columns.

    Raises
    ------
    TypeError
        If `rows` is not an integer, a strategy or None.
    """
    if isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif rows is None:
        rows = draw(_default_array_sizes)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))
    children = [draw(arrays(field.type, size=rows)) for field in schema]
    # TODO(kszucs): the names and schema arguments are not consistent with
    #               Table.from_array's arguments
    return pa.RecordBatch.from_arrays(children, names=schema)


@st.composite
def tables(draw, type, rows=None, max_fields=None):
    """
    Draw a Table.

    Parameters
    ----------
    type : SearchStrategy
        Strategy for the column types, forwarded to ``schemas``.
    rows : int or SearchStrategy, optional
        Number of rows, or a strategy to draw it from. Drawn from
        ``_default_array_sizes`` when omitted.
    max_fields : int, optional
        Upper bound for the number of columns.

    Raises
    ------
    TypeError
        If `rows` is not an integer, a strategy or None.
    """
    if isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif rows is None:
        rows = draw(_default_array_sizes)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))
    children = [draw(arrays(field.type, size=rows)) for field in schema]
    return pa.Table.from_arrays(children, schema=schema)


all_arrays = arrays(all_types)
all_chunked_arrays = chunked_arrays(all_types)
all_record_batches = record_batches(all_types)
all_tables = tables(all_types)


# Define the same rules as above for pandas tests by excluding certain types
# from the generation because of known issues.

# Primitive types which are known to round-trip with pandas.
pandas_compatible_primitive_types = st.one_of(
    null_type,
    bool_type,
    integer_types,
    st.sampled_from([pa.float32(), pa.float64()]),
    decimal128_type,
    date_types,
    time_types,
    # Need to exclude timestamp and duration types otherwise hypothesis
    # discovers ARROW-10210
    # timestamp_types,
    # duration_types
    interval_types,
    binary_type,
    string_type,
    large_binary_type,
    large_string_type,
)

# Need to exclude floating point types otherwise hypothesis discovers
# ARROW-10211
pandas_compatible_dictionary_value_types = (
    bool_type |
    integer_types |
    binary_type |
    string_type |
    fixed_size_binary_type
)


def pandas_compatible_list_types(
    item_strategy=pandas_compatible_primitive_types
):
    """
    Return a strategy for variable size list and large list types whose
    item type is drawn from `item_strategy`.
    """
    # Need to exclude fixed size list type otherwise hypothesis discovers
    # ARROW-10194
    variable_size_lists = st.builds(pa.list_, item_strategy)
    large_lists = st.builds(pa.large_list, item_strategy)
    return st.one_of(variable_size_lists, large_lists)


# Deferred because the definition refers to itself in order to produce
# nested list and struct types.
pandas_compatible_types = st.deferred(
    lambda: (
        pandas_compatible_primitive_types |
        pandas_compatible_list_types(pandas_compatible_primitive_types) |
        struct_types(pandas_compatible_primitive_types) |
        dictionary_types(
            value_strategy=pandas_compatible_dictionary_value_types
        ) |
        pandas_compatible_list_types(pandas_compatible_types) |
        struct_types(pandas_compatible_types)
    )
)