I am new to Rust and I'm looking into using Burn to port some Python/Torch code that I have for a new statistical parametric method.
Baby step 1: I want to generate a (10, 1) tensor with random values generated from a Cauchy distribution with known parameters. The distributions in Burn are very limited, so I'm using statrs. By using statrs, I can get a Vec<f64> and then I can wrap that into a TensorData in Burn and thus generate a Tensor.
I added some type signatures, but Burn has Float rather than the specific f64, and I'm a bit confused by this. In fact, just for debugging purposes, I want to extract the data from the Burn tensor as a Vec<f64> to see it (I should see the same values as in vec: Vec<f64>), but I am getting a runtime type incompatibility.
use rand::prelude::Distribution;
use statrs::distribution::Cauchy;
use rand_chacha::ChaCha8Rng;
use rand_core::SeedableRng;
use burn::tensor::{Tensor, TensorData, Float};
use burn::backend::Wgpu;
type Backend = Wgpu;
fn main() {
    // some global refs
    let device = Default::default();
    let mut rng: ChaCha8Rng = ChaCha8Rng::seed_from_u64(2);
    // create random vec using statrs, store in a Vec<f64>
    let dist: Cauchy = Cauchy::new(5.0, 2.0).unwrap();
    let vec: Vec<f64> = dist.sample_iter(&mut rng).take(10).collect();
    // wrap this into a Burn tensor
    let td: TensorData = TensorData::new(vec, [10, 1]);
    let tensor: Tensor<Backend, 2, Float> = Tensor::<Backend, 2, Float>::from_data(td, &device);
    print!("{:?}\n", tensor.to_data().to_vec::<f64>().unwrap());
}
When running the above, I get:
thread 'main' panicked at src/main.rs:23:55:
called `Result::unwrap()` on an `Err` value: TypeMismatch("Invalid target element type (expected F32, got F64)")
Using to_vec::<f32> works, but I would like the Burn tensor to hold f64 values (Torch supports this), as the error seems to imply that I lost precision at some point, which is not great.
Is storing f64 in a Burn tensor possible?
1 Answer
The type used for floating-point values is a function of the backend. At the time of writing, all backends default to f32, but you can use e.g. type Backend = Wgpu<f64> if you want to use f64 instead.
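To get a feel for what the default f32 element type costs, here is a minimal std-only sketch (no Burn involved) that round-trips f64 values through f32, which is roughly what the error message suggests happened to the data in the question. The helper names round_trip_f32 and max_round_trip_error, and the sample values, are made up for illustration and are not part of Burn or statrs.

```rust
/// Round-trips a value through f32, mimicking f64 data being stored
/// with an f32 element type and then read back as f64.
/// (Hypothetical helper, not a Burn API.)
fn round_trip_f32(x: f64) -> f64 {
    (x as f32) as f64
}

/// Largest absolute error introduced by the f32 round trip.
/// Returns 0.0 for an empty slice.
fn max_round_trip_error(values: &[f64]) -> f64 {
    values
        .iter()
        .map(|&x| (x - round_trip_f32(x)).abs())
        .fold(0.0, f64::max)
}

fn main() {
    // Made-up sample values standing in for the Cauchy draws.
    let samples = [5.1, -3.7, 12.345678901234, 0.1];
    for &x in &samples {
        println!("{x:.15} -> {:.15}", round_trip_f32(x));
    }
    println!("max error: {:e}", max_round_trip_error(&samples));
}
```

If the error this reports is acceptable for your method, reading the tensor back with to_vec::<f32> may be enough; otherwise switching the backend's float type as described in the answer is the way to keep f64 end to end.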