-
Notifications
You must be signed in to change notification settings - Fork 482
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Wasserstein Generative Adversarial Network (#2660)
* Add files via upload Wasserstein Generative Adversarial Network * Delete examples/wgan/readme * Create README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update cli.rs * Update cli.rs * Update model.rs * Update training.rs * Update main.rs * Update model.rs * Update training.rs * Update training.rs * Update main.rs * Update training.rs * Update model.rs * Update training.rs * Update cli.rs * Update cli.rs * Update generating.rs * Update lib.rs * Update model.rs * Update training.rs * Update main.rs * Update generating.rs * Update model.rs * Update training.rs * Update generating.rs * Update model.rs * Update training.rs * Update training.rs * Update dataset.rs * Update generating.rs * Update model.rs * Update training.rs * Update training.rs * Update training.rs * Restructure as workspace example * Add support for single range slice (fixes clippy) * Update example usage + list --------- Co-authored-by: Guillaume Lagrange <[email protected]>
- Loading branch information
1 parent
ad81344
commit f630b3b
Showing
14 changed files
with
752 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[package] | ||
name = "wgan" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[features] | ||
ndarray = ["burn/ndarray"] | ||
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] | ||
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] | ||
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] | ||
tch-cpu = ["burn/tch"] | ||
tch-gpu = ["burn/tch"] | ||
wgpu = ["burn/wgpu"] | ||
cuda-jit = ["burn/cuda-jit"] | ||
|
||
[dependencies] | ||
burn = { path = "../../crates/burn", features=["train", "vision"] } | ||
image = { workspace = true } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Wasserstein Generative Adversarial Network | ||
|
||
A burn implementation of examplar WGAN model to generate MNIST digits inspired by | ||
[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html). | ||
Please note that better performance maybe gained by adopting a convolution layer in | ||
[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch). | ||
|
||
## Usage | ||
|
||
|
||
## Training | ||
|
||
```sh | ||
# Cuda backend | ||
cargo run --example wgan-mnist --release --features cuda-jit | ||
|
||
# Wgpu backend | ||
cargo run --example wgan-mnist --release --features wgpu | ||
|
||
# Tch GPU backend | ||
export TORCH_CUDA_VERSION=cu121 # Set the cuda version | ||
cargo run --example wgan-mnist --release --features tch-gpu | ||
|
||
# Tch CPU backend | ||
cargo run --example wgan-mnist --release --features tch-cpu | ||
|
||
# NdArray backend (CPU) | ||
cargo run --example wgan-mnist --release --features ndarray # f32 - single thread | ||
cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas | ||
cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib | ||
``` | ||
|
||
|
||
### Generating | ||
|
||
To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. | ||
|
||
```sh | ||
cargo run --example wgan-generate --release --features cuda-jit | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use burn::tensor::backend::Backend; | ||
|
||
pub fn launch<B: Backend>(device: B::Device) { | ||
wgan::infer::generate::<B>("/tmp/wgan-mnist", device); | ||
} | ||
|
||
#[cfg(any( | ||
feature = "ndarray", | ||
feature = "ndarray-blas-netlib", | ||
feature = "ndarray-blas-openblas", | ||
feature = "ndarray-blas-accelerate", | ||
))] | ||
mod ndarray { | ||
use burn::backend::{ | ||
ndarray::{NdArray, NdArrayDevice}, | ||
Autodiff, | ||
}; | ||
|
||
use crate::launch; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-gpu")] | ||
mod tch_gpu { | ||
use burn::backend::{ | ||
libtorch::{LibTorch, LibTorchDevice}, | ||
Autodiff, | ||
}; | ||
|
||
use crate::launch; | ||
|
||
pub fn run() { | ||
#[cfg(not(target_os = "macos"))] | ||
let device = LibTorchDevice::Cuda(0); | ||
#[cfg(target_os = "macos")] | ||
let device = LibTorchDevice::Mps; | ||
|
||
launch::<Autodiff<LibTorch>>(device); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-cpu")] | ||
mod tch_cpu { | ||
use burn::backend::{ | ||
libtorch::{LibTorch, LibTorchDevice}, | ||
Autodiff, | ||
}; | ||
|
||
use crate::launch; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu); | ||
} | ||
} | ||
|
||
#[cfg(feature = "wgpu")] | ||
mod wgpu { | ||
use crate::launch; | ||
use burn::backend::{wgpu::Wgpu, Autodiff}; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<Wgpu>>(Default::default()); | ||
} | ||
} | ||
|
||
#[cfg(feature = "cuda-jit")] | ||
mod cuda_jit { | ||
use crate::launch; | ||
use burn::backend::{Autodiff, CudaJit}; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<CudaJit>>(Default::default()); | ||
} | ||
} | ||
|
||
fn main() { | ||
#[cfg(any( | ||
feature = "ndarray", | ||
feature = "ndarray-blas-netlib", | ||
feature = "ndarray-blas-openblas", | ||
feature = "ndarray-blas-accelerate", | ||
))] | ||
ndarray::run(); | ||
#[cfg(feature = "tch-gpu")] | ||
tch_gpu::run(); | ||
#[cfg(feature = "tch-cpu")] | ||
tch_cpu::run(); | ||
#[cfg(feature = "wgpu")] | ||
wgpu::run(); | ||
#[cfg(feature = "cuda-jit")] | ||
cuda_jit::run(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend}; | ||
|
||
use wgan::{model::ModelConfig, training::TrainingConfig}; | ||
|
||
pub fn launch<B: AutodiffBackend>(device: B::Device) { | ||
let config = TrainingConfig::new( | ||
ModelConfig::new(), | ||
RmsPropConfig::new() | ||
.with_alpha(0.99) | ||
.with_momentum(0.0) | ||
.with_epsilon(0.00000008) | ||
.with_weight_decay(None) | ||
.with_centered(false), | ||
); | ||
|
||
wgan::training::train::<B>("/tmp/wgan-mnist", config, device); | ||
} | ||
|
||
#[cfg(any( | ||
feature = "ndarray", | ||
feature = "ndarray-blas-netlib", | ||
feature = "ndarray-blas-openblas", | ||
feature = "ndarray-blas-accelerate", | ||
))] | ||
mod ndarray { | ||
use burn::backend::{ | ||
ndarray::{NdArray, NdArrayDevice}, | ||
Autodiff, | ||
}; | ||
|
||
use crate::launch; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-gpu")] | ||
mod tch_gpu { | ||
use burn::backend::{ | ||
libtorch::{LibTorch, LibTorchDevice}, | ||
Autodiff, | ||
}; | ||
|
||
use crate::launch; | ||
|
||
pub fn run() { | ||
#[cfg(not(target_os = "macos"))] | ||
let device = LibTorchDevice::Cuda(0); | ||
#[cfg(target_os = "macos")] | ||
let device = LibTorchDevice::Mps; | ||
|
||
launch::<Autodiff<LibTorch>>(device); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-cpu")] | ||
mod tch_cpu { | ||
use burn::backend::{ | ||
libtorch::{LibTorch, LibTorchDevice}, | ||
Autodiff, | ||
}; | ||
|
||
use crate::launch; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu); | ||
} | ||
} | ||
|
||
#[cfg(feature = "wgpu")] | ||
mod wgpu { | ||
use crate::launch; | ||
use burn::backend::{wgpu::Wgpu, Autodiff}; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<Wgpu>>(Default::default()); | ||
} | ||
} | ||
|
||
#[cfg(feature = "cuda-jit")] | ||
mod cuda_jit { | ||
use crate::launch; | ||
use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit}; | ||
|
||
pub fn run() { | ||
launch::<Autodiff<CudaJit>>(CudaDevice::default()); | ||
} | ||
} | ||
|
||
fn main() { | ||
#[cfg(any( | ||
feature = "ndarray", | ||
feature = "ndarray-blas-netlib", | ||
feature = "ndarray-blas-openblas", | ||
feature = "ndarray-blas-accelerate", | ||
))] | ||
ndarray::run(); | ||
#[cfg(feature = "tch-gpu")] | ||
tch_gpu::run(); | ||
#[cfg(feature = "tch-cpu")] | ||
tch_cpu::run(); | ||
#[cfg(feature = "wgpu")] | ||
wgpu::run(); | ||
#[cfg(feature = "cuda-jit")] | ||
cuda_jit::run(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use burn::{ | ||
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, | ||
prelude::*, | ||
}; | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct MnistBatcher<B: Backend> { | ||
device: B::Device, | ||
} | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct MnistBatch<B: Backend> { | ||
pub images: Tensor<B, 4>, | ||
pub targets: Tensor<B, 1, Int>, | ||
} | ||
|
||
impl<B: Backend> MnistBatcher<B> { | ||
pub fn new(device: B::Device) -> Self { | ||
Self { device } | ||
} | ||
} | ||
|
||
impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> { | ||
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> { | ||
let images = items | ||
.iter() | ||
.map(|item| TensorData::from(item.image)) | ||
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device)) | ||
.map(|tensor| tensor.reshape([1, 28, 28])) | ||
// Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example | ||
.map(|tensor| ((tensor / 255) - 0.5) / 0.5) | ||
.collect(); | ||
|
||
let targets = items | ||
.iter() | ||
.map(|item| { | ||
Tensor::<B, 1, Int>::from_data( | ||
TensorData::from([(item.label as i64).elem::<B::IntElem>()]), | ||
&self.device, | ||
) | ||
}) | ||
.collect(); | ||
|
||
let images = Tensor::stack(images, 0); | ||
let targets = Tensor::cat(targets, 0); | ||
|
||
MnistBatch { images, targets } | ||
} | ||
} |
Oops, something went wrong.