
I am new to Rust and I'm looking into using Burn to port some Python/Torch code that I have for a new statistical parametric method.

Baby step 1: I want to generate a (10, 1) tensor with random values generated from a Cauchy distribution with known parameters. The distributions in Burn are very limited, so I'm using statrs. By using statrs, I can get a Vec<f64> and then I can wrap that into a TensorData in Burn and thus generate a Tensor.

I added some type signatures, but Burn uses the generic Float kind rather than a specific f64, which confuses me a bit. For debugging purposes, I want to extract the data from the Burn tensor as a Vec<f64> to inspect it (I should see the same values as in vec: Vec<f64>), but I get a type mismatch at runtime.

use rand::prelude::Distribution;
use statrs::distribution::Cauchy;
use rand_chacha::ChaCha8Rng;
use rand_core::SeedableRng;
use burn::tensor::{Tensor, TensorData, Float};
use burn::backend::Wgpu;

type Backend = Wgpu;

fn main() {
    // some global refs
    let device = Default::default();
    let mut rng: ChaCha8Rng = ChaCha8Rng::seed_from_u64(2);

    // create random vec using statrs, store in a Vec<f64>
    let dist: Cauchy = Cauchy::new(5.0, 2.0).unwrap();
    let vec: Vec<f64> = dist.sample_iter(&mut rng).take(10).collect();

    // wrap this into a Burn tensor
    let td: TensorData = TensorData::new(vec, [10, 1]);
    let tensor: Tensor<Backend, 2, Float> = Tensor::<Backend, 2, Float>::from_data(td, &device);

    print!("{:?}\n", tensor.to_data().to_vec::<f64>().unwrap());
}

When running above, I get

thread 'main' panicked at src/main.rs:23:55:
called `Result::unwrap()` on an `Err` value: TypeMismatch("Invalid target element type (expected F32, got F64)")
Using to_vec::<f32> works, but I would like the Burn tensor to hold f64 values (torch supports this), since the error seems to imply that I lost precision at some point - not great.

Is storing f64 in a Burn tensor possible?



1 Answer

The type used for floating-point values is a function of the backend. At the time of writing, all backends default to f32, but you can use e.g. type Backend = Wgpu<f64> if you want to use f64 instead.
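Concretely, assuming a Burn version whose Wgpu alias takes the float element type as its first type parameter (defaulting to f32), the fix is a one-line change to the backend alias:

```rust
// Sketch under that assumption: parameterize the Wgpu backend with f64 so
// Float tensors store 64-bit values end to end. Note that actual f64 support
// depends on your Burn version and on the GPU/driver exposing 64-bit floats.
use burn::backend::Wgpu;

type Backend = Wgpu<f64>;
```

With the backend parameterized this way, tensor.to_data().to_vec::<f64>() should no longer hit the F32/F64 type mismatch, since the tensor's element type matches the requested target type.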
