What's the best way to write a Python function that can be used with float, NumPy, or Pandas data types and always returns the same data type as the arguments it was given? The catch is that the calculation includes one or more float values.

E.g. toy example:

def mycalc(x, a=1.0, b=1.0):
    return a * x + b

(I've simplified the problem a lot here as I would ideally want to have more than one input argument like x, but you can assume that the function is vectorized in the sense that it works with Numpy array arguments and Pandas series).

For Numpy arrays and Pandas Series this works fine because the dtype is dictated by the input arguments.

import numpy as np
x = np.array([1, 2, 3], dtype="float32")
print(mycalc(x).dtype)  # float32
import pandas as pd
x = pd.Series([1.0, 2.0, 3.0], dtype="float32")
print(mycalc(x).dtype)  # float32

But when using numpy floats of lower precision, the dtype is 'lifted' to float64, presumably due to the float arguments in the formula:

x = np.float32(1.0)
print(mycalc(x).dtype)  # float64

Ideally, I would like the function to work with Python floats, numpy scalars, numpy arrays, Pandas series, Jax arrays, and even Sympy symbolic variables if possible.

But I don't want to clutter up the function with too many additional statements to handle each case.

I tried this, which works with Numpy scalars but breaks when you provide arrays or series:

def mycalc(x, a=1.0, b=1.0):
    a = type(x)(a)
    b = type(x)(b)
    return a * x + b

assert isinstance(mycalc(1.0), float)
assert isinstance(mycalc(np.float32(1.0)), np.float32)
mycalc(np.array([1, 2, 3], dtype="float32"))  # raises TypeError: expected a sequence of integers or a single integer, got '1.0'
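
The failure seems to come from what each constructor expects: `float` and `np.float32` take a value to convert, while `np.ndarray` interprets its first argument as a shape. A small sketch of the difference (the variable names are only illustrative):

```python
import numpy as np

# Scalar types treat the constructor argument as a value to convert.
as_float = float(1.0)
as_float32 = np.float32(1.0)

# np.ndarray treats its first argument as a shape, so an integer
# produces an uninitialized array of that length...
uninitialized = np.ndarray(3)
print(uninitialized.shape)  # (3,)

# ...and a float is rejected because it is not a valid shape.
try:
    np.ndarray(1.0)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)
```

So `type(x)(a)` is only a safe cast when `type(x)` is a scalar type.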

Also, there is an answer here to a similar question which uses a decorator function to make copies of the input argument, which is a nice idea, but this was only for extending the function from Numpy arrays to Pandas series and doesn't work with Python floats or Numpy scalars.

import functools

def apply_to_pandas(func):
    @functools.wraps(func)
    def wrapper_func(x, *args, **kwargs):
        if isinstance(x, (np.ndarray, list)):
            out = func(x, *args, **kwargs)
        else:
            out = x.copy(deep=False)
            out[:] = np.apply_along_axis(func, 0, x, *args, **kwargs)
        return out
    return wrapper_func

@apply_to_pandas
def mycalc(x, a=1.0, b=1.0):
    return a * x + b

mycalc(1.0) # TypeError: copy() got an unexpected keyword argument 'deep'

Update

As pointed out by @Dunes in the comments below, this is no longer a problem in Numpy versions 2.x as explained here in the Numpy 2.0 Migration Guide.

In the new version, (np.float32(1.0) + 1).dtype == "float32". Therefore the original function above returns a result of the same dtype as the input x.
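
To check which behaviour an installed NumPy has, the result dtype can be compared against the major version. This is only a sketch, assuming NumPy's default promotion behaviour for each version; `mycalc` is the toy function from the question, and `numpy_major` and `expected` are illustrative names.

```python
import numpy as np

def mycalc(x, a=1.0, b=1.0):
    return a * x + b

numpy_major = int(np.__version__.split(".")[0])
result = mycalc(np.float32(1.0))

# NumPy 1.x promotes (float32 scalar * Python float) to float64;
# NumPy 2.x keeps the dtype of the NumPy scalar.
expected = np.float32 if numpy_major >= 2 else np.float64
print(np.__version__, result.dtype, result.dtype == expected)
```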

Asked Nov 21, 2024 at 16:06 by Bill; edited Nov 23, 2024 at 1:10.
  • What version of numpy are you using? I have been unable to reproduce this with v2.1.0 on python 3.10. The dtype of scalars is preserved. – Dunes Commented Nov 21, 2024 at 17:08
  • @Dunes I am using numpy version 1.26.1 with Python 3.10.12. Can you confirm which part you can't reproduce? This result I am guessing: mycalc(np.float32(1.0)).dtype == np.float64 – Bill Commented Nov 21, 2024 at 18:05
  • It might help if you clearly distinguished between type and dtype. Also look at the [source] for some numpy functions, especially ones that delegate to methods. They often check inputs and convert them as needed to valid arrays (preserving array subclassing as needed). There's a lot of conversion and method delegation going on behind the scenes when using operators. – hpaulj Commented Nov 21, 2024 at 18:44
  • Thanks @hpaulj, I changed the question text to make it clearer that it is the data type (dtype) that I want to match, not the object type (except in the case of Python floats). When you say I should read the docs for Numpy functions, are you suggesting there's a way to change the default behaviour of operations so that it doesn't 'lift' the precision? – Bill Commented Nov 21, 2024 at 18:52
  • I was able to reproduce with numpy 1.26.1, but not with version 2.0.2. So it would appear that the short answer to all this is to upgrade to at least version 2 of numpy. And if you cannot, then explain why in the question, i.e. (np.float32(1) + 1).dtype == np.dtype('float32') is true in version 2.x, but false in version 1.x – Dunes Commented Nov 22, 2024 at 19:22

3 Answers

I don't mean for this to be an authoritative answer, but rather something to think about that might help get you farther along. What if you tried to rely on the more advanced types' implementations of the "r" dunder methods, which seem to be more nuanced, and did something like this:

import numpy as np
import pandas as pd

def mycalc(x, a=1, b=1):
    foo = x * a + b
    return foo if type(foo) == type(x) else type(x)(foo)

print(type(mycalc(1)))
print(type(mycalc(1.0)))
print(type(mycalc(np.float32(1.0))))
print(type(mycalc(np.array([1, 2, 3], dtype="float32"))))
print(type(mycalc(pd.Series([1, 2, 3], dtype="float64"))))

That seems to give back:

<class 'int'>
<class 'float'>
<class 'numpy.float32'>
<class 'numpy.ndarray'>
<class 'pandas.core.series.Series'>

Post a comment here so I will know you saw this, and then I will remove it; again, it is not really an authoritative answer in my opinion, just an idea.

This doesn't exactly solve the problem I posed, since the desired data type must be specified here, but I think it's a simple, robust solution compared with trying to do the conversions automatically based on the input types.

def mycalc(x, a=1.0, b=1.0, float_type=float):
    a = float_type(a)
    b = float_type(b)
    return a * x + b

assert isinstance(mycalc(1.0), float)
assert type(mycalc(np.float32(1.0), float_type=np.float32)) == np.float32
x = np.array([1.0, 2.0, 3.0], dtype="float32")
assert mycalc(x, float_type=np.float32).dtype == np.float32
x = pd.Series([1.0, 2.0, 3.0], dtype="float32")
assert mycalc(x, float_type=np.float32).dtype == np.float32
x = pd.Series([1.0, 2.0, 3.0], dtype="float64")
assert mycalc(x, float_type=np.float64).dtype == np.float64

It would be nice if the conversions could be done by a decorator, but since the floats are default values of keyword arguments, a plain decorator never sees them in its *args or **kwargs.

I'm still looking for better solutions, but I'm posting this here because it might be a good workaround in some cases.
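
For what it's worth, default values are visible to a decorator through `inspect.signature`, so one possible route is to bind the call, fill in the defaults, and cast any float parameter before calling the function. The following is only a sketch under some assumptions: the decorator name `cast_float_params` is hypothetical, the first argument is taken as the reference, and the target type must accept a float in its constructor (so this would not cover Sympy symbols). `Fraction` stands in here as a non-float scalar type from the standard library.

```python
import functools
import inspect
from fractions import Fraction


def cast_float_params(func):
    """Hypothetical decorator: cast float parameters to the scalar type of the first argument."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(x, *args, **kwargs):
        bound = sig.bind(x, *args, **kwargs)
        bound.apply_defaults()  # fills in a=1.0, b=1.0 when they are omitted
        # Arrays and Series expose dtype.type; plain scalars use their own type.
        dtype = getattr(x, "dtype", None)
        scalar_type = type(x) if dtype is None else dtype.type
        first = next(iter(bound.arguments))  # name of the reference argument
        for name, value in list(bound.arguments.items()):
            if name != first and isinstance(value, float):
                bound.arguments[name] = scalar_type(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper


@cast_float_params
def mycalc(x, a=1.0, b=1.0):
    return a * x + b


print(type(mycalc(2.0)))       # <class 'float'>
print(mycalc(Fraction(1, 2)))  # 3/2
```

With a NumPy scalar, array, or Pandas Series as `x`, the `dtype.type` branch should pick up the element type (e.g. `np.float32`), though I have only exercised the standard-library cases above.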

This works for numpy and pandas types. I'm not sure this is the best way though.

import numpy as np
import pandas as pd

def get_item_type(var):
    try:
        var = var.to_numpy()
    except AttributeError:
        pass
    try:
        t = type(var.flat[0])
    except AttributeError:
        t = type(var)
    return t


def mycalc(x, a=1.0, b=1.0):
    float_type = get_item_type(x)
    a = float_type(a)
    b = float_type(b)
    return a * x + b


assert isinstance(mycalc(1.0), float)
assert type(mycalc(np.float32(1.0))) == np.float32
x = np.array([1.0, 2.0, 3.0], dtype="float32")
assert mycalc(x).dtype == np.float32
x = pd.Series([1.0, 2.0, 3.0], dtype="float32")
assert mycalc(x).dtype == np.float32
x = pd.Series([1.0, 2.0, 3.0], dtype="float64")
assert mycalc(x).dtype == np.float64
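
One edge case worth noting: `var.flat[0]` needs at least one element, so an empty array raises an `IndexError` that `get_item_type` does not catch. A possible alternative, sketched here under the assumption that every array-like input exposes a NumPy-style `dtype` attribute (the name `get_scalar_type` is hypothetical), is to ask the dtype for its scalar type:

```python
import numpy as np


def get_scalar_type(var):
    # Hypothetical helper: use dtype.type when a dtype exists,
    # otherwise fall back to the object's own type (e.g. float).
    dtype = getattr(var, "dtype", None)
    return type(var) if dtype is None else dtype.type


empty = np.array([], dtype="float32")

try:
    empty.flat[0]
    flat_failed = False
except IndexError:
    flat_failed = True  # there is no first element to inspect

print(flat_failed)                       # True
print(get_scalar_type(empty))            # <class 'numpy.float32'>
print(get_scalar_type(np.float32(1.0)))  # <class 'numpy.float32'>
print(get_scalar_type(1.0))              # <class 'float'>
```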
