admin管理员组

文章数量:1296887

Inspired by the Polars DS project, I have written a Polars plugin that calculates a symmetric correlation matrix.

On the Rust side I have:

#![allow(unused_imports)]
use polars::prelude::*;
use pyo3_polars::{derive::polars_expr, export::polars_core::random};
use serde::{Deserialize, Serialize};

use kendalls::tau_b;

/// Correlation coefficient requested by the Python caller.
///
/// Deserialized from the lowercase strings `"pearson"`, `"spearman"` and `"kendall"`.
#[derive(Deserialize)]
enum CorrMethod {
    /// Pearson product-moment correlation of the raw values.
    #[serde(rename = "pearson")]
    Pearson,
    /// Pearson correlation computed on the ranks of the values.
    #[serde(rename = "spearman")]
    Spearman,
    /// Kendall rank correlation (accepted here, but `corr` does not compute it).
    #[serde(rename = "kendall")]
    Kendall,
}

/// Deserializable stand-in for polars' `RankMethod`, converted through `From`.
///
/// Deserialized from the lowercase strings `"average"`, `"min"`, `"max"`, `"dense"`
/// and `"ordinal"`.
#[derive(Deserialize)]
enum RankMethodSerde {
    /// Maps to `RankMethod::Average`.
    #[serde(rename = "average")]
    Average,
    /// Maps to `RankMethod::Min`.
    #[serde(rename = "min")]
    Min,
    /// Maps to `RankMethod::Max`.
    #[serde(rename = "max")]
    Max,
    /// Maps to `RankMethod::Dense`.
    #[serde(rename = "dense")]
    Dense,
    /// Maps to `RankMethod::Ordinal`.
    #[serde(rename = "ordinal")]
    Ordinal,
}

impl From<RankMethodSerde> for RankMethod {
    /// Converts the deserialized rank method into the polars variant of the same name.
    fn from(serde_method: RankMethodSerde) -> Self {
        match serde_method {
            RankMethodSerde::Average => Self::Average,
            RankMethodSerde::Min => Self::Min,
            RankMethodSerde::Max => Self::Max,
            RankMethodSerde::Dense => Self::Dense,
            RankMethodSerde::Ordinal => Self::Ordinal,
        }
    }
}
/// Keyword arguments sent by the Python wrapper to the `corr` expression.
#[derive(Deserialize)]
struct Kwargs {
    /// Which correlation coefficient to compute.
    method: CorrMethod,
    /// Minimum number of rows where both inputs are non-null; with fewer, `corr`
    /// returns a single null.
    min_periods: u32,
    /// Tie-handling strategy for ranking; only read for Spearman.
    rank_method: Option<RankMethodSerde>,
    /// Whether to rank in descending order; only read for Spearman.
    descending_rank: Option<bool>,
}

#[polars_expr(output_type = Float32)]
pub fn corr(series: &[Series], kwargs: Kwargs) -> PolarsResult<Series> {
    let col_name = series[0].name();

    let (series_x, series_y) = {
        let series_x_full = &series[0];
        let series_y_full = &series[1];
        let mask = series_x_full.is_not_null() & series_y_full.is_not_null();
        if mask.sum().unwrap_or(0) < kwargs.min_periods {
            return Ok(Float32Chunked::new(col_name.clone(), [None]).into_series());
        }

        (&series_x_full.filter(&mask)?, &series_y_full.filter(&mask)?)
    };

    let corr = match kwargs.method {
        CorrMethod::Pearson => pearson_corr(
            series_x.cast(&DataType::Float32)?.f32()?,
            series_y.cast(&DataType::Float32)?.f32()?,
        ),
        CorrMethod::Spearman => {
            let rank_method: RankMethod = kwargs.rank_method.unwrap().into();
            let descending_rank = kwargs.descending_rank.unwrap();

            let ranked_series_x = &series_x.rank(
                RankOptions {
                    method: rank_method,
                    descending: descending_rank,
                },
                None,
            );
            let ranked_series_y = &series_y.rank(
                RankOptions {
                    method: rank_method,
                    descending: descending_rank,
                },
                None,
            );

            spearman_corr(
                ranked_series_x.cast(&DataType::Float32)?.f32()?,
                ranked_series_y.cast(&DataType::Float32)?.f32()?,
            )
        }
        _ => unimplemented!(),
    };

    Ok(Series::from_vec(col_name.clone(), vec![corr]))
}

/// Pearson product-moment correlation of two Float32 arrays.
///
/// Two-pass form: centre each array on its mean, then
/// `r = Σ(dx·dy) / (sqrt(Σdx²) · sqrt(Σdy²))`.
/// Assumes both arrays have the same length and null-free rows aligned, as produced by
/// `corr`'s shared null mask.
///
/// Returns `NaN` when either array has no non-null values, and `NaN` (`0 / 0`) when
/// either input has zero variance.
fn pearson_corr(array_x: &ChunkedArray<Float32Type>, array_y: &ChunkedArray<Float32Type>) -> f32 {
    // `mean` is `None` for an empty or all-null array. `corr` can reach this with
    // `min_periods == 0`, so answer NaN instead of panicking.
    let (Some(mean_x), Some(mean_y)) = (array_x.mean(), array_y.mean()) else {
        return f32::NAN;
    };

    let s_x = array_x - mean_x;
    let s_y = array_y - mean_y;

    let s_xy = (&s_x * &s_y).sum().unwrap_or(0.0);
    let s_x_sq = (&s_x * &s_x).sum().unwrap_or(0.0);
    let s_y_sq = (&s_y * &s_y).sum().unwrap_or(0.0);

    // Taking the roots separately keeps the denominator from overflowing f32 when
    // both sums of squares are large.
    s_xy / (s_x_sq.sqrt() * s_y_sq.sqrt())
}

fn spearman_corr(
    ranked_array_x: &ChunkedArray<Float32Type>,
    ranked_array_y: &ChunkedArray<Float32Type>,
) -> f32 {
    pearson_corr(ranked_array_x, ranked_array_y)
}

On the Python side I have:

from typing import Literal, Optional

from polars import col
import polars.selectors as cs

import polars as pl

from ..utils import polars_plugin, parse_to_expr


def _corr(
    x_vec: pl.Expr,
    y_vec: pl.Expr,
    method: Literal["pearson", "kendall", "spearman"] = "pearson",
    min_periods: int = 1,
    rank_method: Literal["average", "min", "max", "dense", "ordinal"]
    | None = "average",
    descending_rank: bool | None = False,
    rounding: int | None = None,
) -> pl.Expr:
    """Return a scalar expression holding the correlation of two inputs.

    Args:
        x_vec: First input, passed through ``parse_to_expr``
            (``polars_corr`` passes a column name here).
        y_vec: Second input, handled like ``x_vec``.
        method: Correlation coefficient to compute.
        min_periods: Minimum number of rows where both inputs are non-null;
            with fewer, the plugin returns null.
        rank_method: How ties are ranked for ``"spearman"``. ``None`` is
            treated as ``"average"``.
        descending_rank: Rank in descending order for ``"spearman"``.
            ``None`` is treated as ``False``.
        rounding: If not ``None``, round the result to this many decimals.

    Returns:
        An expression that evaluates to a single floating-point value.
    """
    # Substitute the defaults for None so the plugin always receives concrete
    # values for the Spearman options.
    if rank_method is None:
        rank_method = "average"
    if descending_rank is None:
        descending_rank = False

    corr_expr = polars_plugin(
        "corr",
        args=[parse_to_expr(x_vec), parse_to_expr(y_vec)],
        kwargs={
            "method": method,
            "min_periods": min_periods,
            "rank_method": rank_method,
            "descending_rank": descending_rank,
        },
        returns_scalar=True,
        changes_length=True,
    )
    if rounding is None:
        return corr_expr
    return corr_expr.round(rounding)


def polars_corr(
    data: pl.DataFrame | pl.LazyFrame,
    *,
    method: Literal["pearson", "kendall", "spearman"] = "pearson",
    min_periods: int = 1,
    rank_method: Literal["average", "min", "max", "dense", "ordinal"]
    | None = "average",
    descending_rank: bool | None = False,
    rounding: int | None = None,
) -> pl.DataFrame:
    """Compute the correlation matrix of the numeric columns of ``data``.

    Numeric columns are cast to Float32 first. Because corr(x, y) equals
    corr(y, x), only the pairs on and above the diagonal are evaluated
    (n * (n + 1) / 2 plugin calls instead of n * n), all inside a single
    ``select``. The lower triangle is then filled by mirroring.

    Note: for ``"spearman"`` the plugin ranks both inputs of every pair, so a
    column is re-ranked once for each pair it takes part in.

    Args:
        data: Frame whose numeric columns are correlated.
        method: Correlation coefficient to compute.
        min_periods: Minimum number of rows where both columns are non-null;
            pairs with fewer get a null entry.
        rank_method: Tie handling for ``"spearman"`` ranking.
        descending_rank: Rank in descending order for ``"spearman"``.
        rounding: If not ``None``, round every entry to this many decimals.

    Returns:
        A frame with a String ``"feature"`` column listing the numeric column
        names, followed by one Float32 column per numeric column, both in the
        input column order. The entry at row ``x``, column ``y`` is
        corr(x, y).
    """
    lf = data.lazy().select(cs.numeric()).cast(pl.Float32)
    columns = lf.collect_schema().names()
    n = len(columns)

    # Upper triangle including the diagonal. The diagonal is still computed so
    # that min_periods and zero-variance columns are handled by the plugin.
    pairs = [(i, j) for i in range(n) for j in range(i, n)]
    values: dict[tuple[int, int], float | None] = {}
    if pairs:
        upper = lf.select(
            [
                _corr(
                    columns[i],
                    columns[j],
                    method,
                    min_periods,
                    rank_method,
                    descending_rank,
                    rounding,
                ).alias(f"{i},{j}")  # positional aliases are always unique
                for i, j in pairs
            ]
        ).collect()
        values = dict(zip(pairs, upper.row(0)))

    return pl.DataFrame(
        [
            pl.Series("feature", columns, dtype=pl.String),
            *(
                pl.Series(
                    y,
                    [values[(min(i, j), max(i, j))] for i in range(n)],
                    dtype=pl.Float32,
                )
                for j, y in enumerate(columns)
            ),
        ]
    )

The program matches the behavior of pd.DataFrame.corr() and outputs the same table. I would like to gather feedback from the community:

  1. While the Pearson method is blazingly fast (significantly faster than pandas'), the Spearman method is slow. Is that caused by how the ranking (RankMethod) is implemented? How can I improve it?
  2. The current Python-side implementation, which assembles the matrix from scalar DataFrames, computes every pairwise correlation twice (once for (x, y) and once for (y, x)). How can I avoid these redundant calculations?

Tags: rust, Non-Symmetric Calculation for Correlation Matrix in Polars Plugin, Stack Overflow