admin管理员组文章数量:1296887
I have written a Polars plugin, inspired by the polars-ds project, that calculates a symmetric correlation matrix.
On the Rust side I have:
#![allow(unused_imports)]
use polars::prelude::*;
use pyo3_polars::{derive::polars_expr, export::polars_core::random};
use serde::{Deserialize, Serialize};
use kendalls::tau_b;
/// Correlation coefficient requested by the caller.
///
/// Deserialized from the lowercase strings `"pearson"`, `"spearman"` and `"kendall"`
/// (see `rename_all` below). It is a plain unit enum, so it is `Copy` and can be
/// matched by value without moving it out of the surrounding kwargs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum CorrMethod {
    /// Pearson product-moment correlation.
    Pearson,
    /// Spearman rank correlation: Pearson correlation computed on ranks.
    Spearman,
    /// Kendall rank correlation; currently not computed by `corr`.
    Kendall,
}
/// Serde-deserializable mirror of Polars' `RankMethod`.
///
/// It exists because the Polars enum cannot be deserialized directly from the plugin
/// kwargs. It accepts the lowercase strings `"average"`, `"min"`, `"max"`, `"dense"` and
/// `"ordinal"`, and is converted with the `From` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum RankMethodSerde {
    /// Tied values share the mean of their ranks.
    Average,
    /// Tied values share the lowest rank of the tie group.
    Min,
    /// Tied values share the highest rank of the tie group.
    Max,
    /// Like `Min`, but ranks increase by one between distinct values (no gaps).
    Dense,
    /// Every value gets a distinct rank, in order of appearance.
    Ordinal,
}
impl From<RankMethodSerde> for RankMethod {
fn from(method: RankMethodSerde) -> Self {
match method {
RankMethodSerde::Average => RankMethod::Average,
RankMethodSerde::Min => RankMethod::Min,
RankMethodSerde::Max => RankMethod::Max,
RankMethodSerde::Dense => RankMethod::Dense,
RankMethodSerde::Ordinal => RankMethod::Ordinal,
}
}
}
/// Keyword arguments passed from the Python wrapper to the `corr` expression.
///
/// The field names are the deserialization keys, so they must stay in sync with the
/// `kwargs` dict that the Python side builds ("method", "min_periods", "rank_method",
/// "descending_rank").
#[derive(Deserialize)]
struct Kwargs {
    /// Which correlation coefficient to compute.
    method: CorrMethod,
    /// Minimum number of rows where both inputs are non-null. With fewer complete
    /// rows, `corr` returns a single null. Unsigned, so a negative Python value
    /// fails to deserialize.
    min_periods: u32,
    /// Ranking strategy; only read by the Spearman branch of `corr`. The Python
    /// signature allows `None` here.
    rank_method: Option<RankMethodSerde>,
    /// Whether ranks are assigned in descending order; only read by the Spearman
    /// branch of `corr`. The Python signature allows `None` here.
    descending_rank: Option<bool>,
}
#[polars_expr(output_type = Float32)]
pub fn corr(series: &[Series], kwargs: Kwargs) -> PolarsResult<Series> {
let col_name = series[0].name();
let (series_x, series_y) = {
let series_x_full = &series[0];
let series_y_full = &series[1];
let mask = series_x_full.is_not_null() & series_y_full.is_not_null();
if mask.sum().unwrap_or(0) < kwargs.min_periods {
return Ok(Float32Chunked::new(col_name.clone(), [None]).into_series());
}
(&series_x_full.filter(&mask)?, &series_y_full.filter(&mask)?)
};
let corr = match kwargs.method {
CorrMethod::Pearson => pearson_corr(
series_x.cast(&DataType::Float32)?.f32()?,
series_y.cast(&DataType::Float32)?.f32()?,
),
CorrMethod::Spearman => {
let rank_method: RankMethod = kwargs.rank_method.unwrap().into();
let descending_rank = kwargs.descending_rank.unwrap();
let ranked_series_x = &series_x.rank(
RankOptions {
method: rank_method,
descending: descending_rank,
},
None,
);
let ranked_series_y = &series_y.rank(
RankOptions {
method: rank_method,
descending: descending_rank,
},
None,
);
spearman_corr(
ranked_series_x.cast(&DataType::Float32)?.f32()?,
ranked_series_y.cast(&DataType::Float32)?.f32()?,
)
}
_ => unimplemented!(),
};
Ok(Series::from_vec(col_name.clone(), vec![corr]))
}
/// Pearson product-moment correlation of two `Float32` arrays of equal length.
///
/// The two-pass algorithm (means first, then centred cross-products) accumulates in
/// `f64`, so long columns do not lose precision in the sums of squares. The result is
/// narrowed to `f32` to match the plugin's output type. Positions where either value is
/// null are skipped.
///
/// Returns `NaN` when there is no complete pair, or when either input has zero variance
/// (the `0 / 0` case).
fn pearson_corr(array_x: &ChunkedArray<Float32Type>, array_y: &ChunkedArray<Float32Type>) -> f32 {
    // Re-creatable iterator over complete (x, y) pairs, widened to f64.
    let pairs = || {
        array_x
            .into_iter()
            .zip(array_y.into_iter())
            .filter_map(|(x, y)| Some((f64::from(x?), f64::from(y?))))
    };

    let (count, sum_x, sum_y) = pairs().fold((0_usize, 0.0_f64, 0.0_f64), |(n, sx, sy), (x, y)| {
        (n + 1, sx + x, sy + y)
    });
    if count == 0 {
        return f32::NAN;
    }
    let mean_x = sum_x / count as f64;
    let mean_y = sum_y / count as f64;

    let (s_xy, s_x_sq, s_y_sq) = pairs().fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(xy, xx, yy), (x, y)| {
            let dx = x - mean_x;
            let dy = y - mean_y;
            (xy + dx * dy, xx + dx * dx, yy + dy * dy)
        },
    );
    (s_xy / (s_x_sq * s_y_sq).sqrt()) as f32
}
fn spearman_corr(
ranked_array_x: &ChunkedArray<Float32Type>,
ranked_array_y: &ChunkedArray<Float32Type>,
) -> f32 {
pearson_corr(ranked_array_x, ranked_array_y)
}
On the Python side I have:
from typing import Literal, Optional
from polars import col
import polars.selectors as cs
import polars as pl
from ..utils import polars_plugin, parse_to_expr
def _corr(
    x_vec: str | pl.Expr,
    y_vec: str | pl.Expr,
    method: Literal["pearson", "kendall", "spearman"] = "pearson",
    min_periods: int = 1,
    rank_method: Literal["average", "min", "max", "dense", "ordinal"]
    | None = "average",
    descending_rank: bool | None = False,
    rounding: int | None = None,
) -> pl.Expr:
    """Build a scalar expression for the correlation between two columns.

    Args:
        x_vec: First column, as a column name or expression (passed through
            ``parse_to_expr``).
        y_vec: Second column, same conventions as ``x_vec``.
        method: Correlation coefficient forwarded to the ``corr`` plugin.
        min_periods: Minimum number of rows where both inputs are non-null. Must
            be non-negative because the plugin deserializes it as an unsigned integer.
        rank_method: Tie handling for ``"spearman"``. ``None`` means ``"average"``.
        descending_rank: Rank in descending order for ``"spearman"``. ``None`` means
            ``False``.
        rounding: If given, round the result to this many decimal places.

    Returns:
        An expression that evaluates to a single correlation value.

    Raises:
        ValueError: If ``min_periods`` is negative.
    """
    if min_periods < 0:
        raise ValueError(f"min_periods must be non-negative, got {min_periods}")

    corr_expr = polars_plugin(
        "corr",
        args=[parse_to_expr(x_vec), parse_to_expr(y_vec)],
        kwargs={
            "method": method,
            "min_periods": min_periods,
            # Normalise None to the documented defaults so the plugin always
            # receives concrete values.
            "rank_method": rank_method if rank_method is not None else "average",
            "descending_rank": (
                descending_rank if descending_rank is not None else False
            ),
        },
        returns_scalar=True,
        changes_length=True,
    )
    return corr_expr.round(rounding) if rounding is not None else corr_expr
def polars_corr(
    data: pl.DataFrame | pl.LazyFrame,
    *,
    method: Literal["pearson", "kendall", "spearman"] = "pearson",
    min_periods: int = 1,
    rank_method: Literal["average", "min", "max", "dense", "ordinal"]
    | None = "average",
    descending_rank: bool | None = False,
    rounding: int | None = None,
) -> pl.DataFrame:
    """Pairwise correlation matrix of the numeric columns of ``data``.

    Non-numeric columns are ignored and numeric ones are cast to Float32 first.
    Correlation is symmetric, so only the upper triangle (diagonal included) is
    evaluated: ``n * (n + 1) / 2`` plugin calls instead of ``n * n``. The lower
    triangle is mirrored from it.

    Args:
        data: Eager or lazy frame to analyse.
        method: Correlation coefficient, forwarded to ``_corr``.
        min_periods: Minimum number of complete rows per pair.
        rank_method: Tie handling for ``"spearman"``.
        descending_rank: Rank in descending order for ``"spearman"``.
        rounding: If given, round every coefficient to this many decimals.

    Returns:
        A square DataFrame. The String column ``feature`` names each row, followed by
        one Float32 column per numeric input column. If there are no numeric columns,
        the result is an empty frame with only ``feature``.

    Raises:
        ValueError: If a numeric column is named ``"feature"`` (it would clash with
            the label column).
    """
    data = data.lazy().select(cs.numeric()).cast(pl.Float32)
    columns = data.collect_schema().names()
    if "feature" in columns:
        raise ValueError(
            "a numeric column named 'feature' clashes with the label column"
        )

    n_cols = len(columns)
    schema = {"feature": pl.String, **dict.fromkeys(columns, pl.Float32)}
    if n_cols == 0:
        return pl.DataFrame(schema=schema)

    pairs = [(i, j) for i in range(n_cols) for j in range(i, n_cols)]
    pair_exprs = [
        _corr(
            columns[i],
            columns[j],
            method=method,
            min_periods=min_periods,
            rank_method=rank_method,
            descending_rank=descending_rank,
            rounding=rounding,
        ).alias(f"{i},{j}")
        for i, j in pairs
    ]
    # One select lets Polars evaluate all pair expressions in parallel; each
    # expression yields a single value, so the result has exactly one row.
    values = data.select(pair_exprs).collect().row(0)

    matrix: list[list[float | None]] = [[None] * n_cols for _ in range(n_cols)]
    for (i, j), value in zip(pairs, values):
        matrix[i][j] = matrix[j][i] = value

    return pl.DataFrame(
        {
            "feature": columns,
            **{name: [row[k] for row in matrix] for k, name in enumerate(columns)},
        },
        schema=schema,
    )
The program matches the behavior of `pd.DataFrame.corr()` and outputs the same table. I would like feedback from the community on two points:
- The Pearson method is blazingly fast (significantly faster than pandas), but the Spearman method is much slower. Is this caused by the implementation of `RankMethod`? How can I improve it?
- The current implementation, which assembles the scalar `DataFrame`s on the Python side, does double work on the pairwise correlations. How can I avoid these unnecessary calculations?
本文标签: rustNonSymmetric Calculation for Correlation Matrix in Polars PluginStack Overflow
版权声明:本文标题:rust - Non-Symmetric Calculation for Correlation Matrix in Polars Plugin - Stack Overflow 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.betaflare.com/web/1741645344a2390154.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论