I have a function that makes grouped data uniform in size by padding each group (a run of consecutive, equal group identifiers) up to the size of the largest group, filling the missing entries with a fill_value. The function currently uses a for loop to populate the padded array.

Is there a better way to generate the padded array and get rid of the for loop using NumPy's built-in vectorized operations?

Here is the function:

import numpy as np

def ensure_uniform_groups(
        groups: np.ndarray,
        values: np.ndarray,
        fill_value: np.number = np.nan) -> tuple[np.ndarray, np.ndarray]:
    """
    Ensure uniform group lengths by padding each group to the same size.

    Args:
        groups : np.ndarray
            1D array of group identifiers, assumed to be consecutive.
        values : np.ndarray
            1D/2D array of values corresponding to the group identifiers.
        fill_value : np.number, optional
            Value to use for padding groups. Default is np.nan.

    Returns:
        tuple[np.ndarray, np.ndarray]
            A tuple (full_groups, full_values): the group identifiers repeated
            to the common group size, and the values padded with fill_value.
    """
    # set common type
    dtype = np.result_type(fill_value, values)

    # derive group infos
    n = groups.size
    mask = np.r_[True, groups[:-1] != groups[1:]]
    starts = np.arange(n)[mask]
    ends = np.r_[starts[1:] - 1, n-1]
    sizes = ends - starts + 1
    max_size = np.max(sizes)

    # check if data is uniform already
    if np.all(sizes == max_size):
        return groups, values

    # generate uniform arrays
    unique_groups = groups[starts]
    full_groups = np.repeat(unique_groups, max_size)
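    # values.shape[1:] keeps this working for both 1D and 2D values
    full_values = np.full((full_groups.shape[0], *values.shape[1:]), fill_value=fill_value, dtype=dtype)
    # copy each group into the start of its max_size-long slot; the rest keeps fill_value
    for i, (ia, ie) in enumerate(np.column_stack([starts, ends+1])):
        ua = i * max_size
        ue = ua + ie-ia
        full_values[ua:ue] = values[ia:ie]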
    return full_groups, full_values
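One possible loop-free variant is sketched below. It keeps the same run detection but replaces the loop with a boolean mask of shape `(n_groups, max_size)` that is True wherever a slot holds a real value; assigning through that mask fills the slots in row-major order, which matches the order of the input rows because equal group ids are assumed to be consecutive. The name `ensure_uniform_groups_vectorized` is hypothetical, and this is a sketch rather than a drop-in replacement checked against every edge case.

```python
import numpy as np


def ensure_uniform_groups_vectorized(
        groups: np.ndarray,
        values: np.ndarray,
        fill_value=np.nan) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical loop-free variant of ensure_uniform_groups.

    Assumes equal ids in `groups` form consecutive runs, as in the original.
    """
    dtype = np.result_type(fill_value, values)

    # index where each run of equal consecutive ids starts, and each run's length
    starts = np.flatnonzero(np.r_[True, groups[:-1] != groups[1:]])
    sizes = np.diff(np.r_[starts, groups.size])
    max_size = sizes.max()

    # already uniform: return the inputs unchanged, like the original
    if np.all(sizes == max_size):
        return groups, values

    n_groups = starts.size
    full_groups = np.repeat(groups[starts], max_size)

    # mask[g, k] is True when slot k of group g holds a real value
    mask = np.arange(max_size) < sizes[:, None]

    # padded values as (n_groups, max_size, ...) so the mask indexes the first two axes
    full_values = np.full((n_groups, max_size, *values.shape[1:]),
                          fill_value, dtype=dtype)
    # boolean indexing visits True slots in row-major order, i.e. input row order
    full_values[mask] = values

    # flatten the group axis back into rows
    return full_groups, full_values.reshape(n_groups * max_size, *values.shape[1:])
```

The extra group axis lets a single masked assignment replace the per-group slice copies; `values.shape[1:]` keeps it usable for both 1D and 2D `values`.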

Here is an example:

groups = np.array([1, 1, 1, 2, 2, 3])   # largest group has 3 elements, so every group is padded to 3
values = np.column_stack([groups*10, groups*100])
fill_value = np.nan
ugroups, uvalues = ensure_uniform_groups(groups, values, fill_value)
out = np.vstack([ugroups, uvalues.T])
print(out)
# [[  1.   1.   1.   2.   2.   2.   3.   3.   3.]
#  [ 10.  10.  10.  20.  20.  nan  30.  nan  nan]
#  [100. 100. 100. 200. 200.  nan 300.  nan  nan]]
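Another way to drop the loop is to compute, for every input row, its destination row in the padded array: the group's index times `max_size` plus the row's offset inside its own group. A single fancy-indexed assignment then replaces the loop. The sketch below applies this to the example data; `pos` and `dest` are illustrative names, not part of the original function, and the result should match the padded values in the example output.

```python
import numpy as np

groups = np.array([1, 1, 1, 2, 2, 3])
values = np.column_stack([groups * 10, groups * 100])

# run starts and lengths, assuming equal ids are consecutive
starts = np.flatnonzero(np.r_[True, groups[:-1] != groups[1:]])
sizes = np.diff(np.r_[starts, groups.size])
max_size = sizes.max()

# offset of each row inside its own group: [0, 1, 2, 0, 1, 0]
pos = np.arange(groups.size) - np.repeat(starts, sizes)
# destination row in the padded array: group index * max_size + offset
dest = np.repeat(np.arange(starts.size), sizes) * max_size + pos

full_groups = np.repeat(groups[starts], max_size)
full_values = np.full((full_groups.size, values.shape[1]), np.nan)
full_values[dest] = values  # one scatter replaces the per-group loop

print(np.vstack([full_groups, full_values.T]))
```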