Skip to content

Commit

Permalink
Integrate CUDA plotter into farmer CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
nazar-pc committed Sep 7, 2024
1 parent d55a43d commit 54943e6
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@ use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use subspace_core_primitives::crypto::kzg::{embedded_kzg_settings, Kzg};
use subspace_core_primitives::Record;
use subspace_erasure_coding::ErasureCoding;
use subspace_farmer::cluster::controller::ClusterPieceGetter;
use subspace_farmer::cluster::nats_client::NatsClient;
use subspace_farmer::cluster::plotter::plotter_service;
use subspace_farmer::plotter::cpu::CpuPlotter;
#[cfg(feature = "cuda")]
use subspace_farmer::plotter::gpu::cuda::CudaRecordsEncoder;
#[cfg(feature = "_gpu")]
use subspace_farmer::plotter::gpu::GpuPlotter;
use subspace_farmer::plotter::pool::PoolPlotter;
use subspace_farmer::plotter::Plotter;
use subspace_farmer::utils::{
create_plotting_thread_pool_manager, parse_cpu_cores_sets, thread_pool_core_indices,
};
Expand All @@ -23,6 +30,8 @@ use subspace_proof_of_space::Table;
use tokio::sync::Semaphore;
use tracing::info;

const PLOTTING_RETRY_INTERVAL: Duration = Duration::from_secs(5);

#[derive(Debug, Parser)]
struct CpuPlottingOptions {
/// Defines how many sectors farmer will download concurrently, allows to limit memory usage of
Expand All @@ -38,9 +47,9 @@ struct CpuPlottingOptions {
/// `--cpu-sector-downloading-concurrency` and setting this option higher than
/// `--cpu-sector-downloading-concurrency` will have no effect.
///
/// Increase will result in higher memory usage.
/// Increase will result in higher memory usage, setting to 0 will disable CPU plotting.
#[arg(long)]
cpu_sector_encoding_concurrency: Option<NonZeroUsize>,
cpu_sector_encoding_concurrency: Option<usize>,
/// Defines how many records farmer will encode in a single sector concurrently, defaults to one
/// record per 2 cores, but not more than 8 in total. Higher concurrency means higher memory
/// usage and typically more efficient CPU utilization.
Expand Down Expand Up @@ -72,6 +81,24 @@ struct CpuPlottingOptions {
cpu_plotting_thread_priority: PlottingThreadPriority,
}

#[cfg(feature = "cuda")]
#[derive(Debug, Parser)]
struct CudaPlottingOptions {
/// Defines how many sectors farmer will download concurrently during plotting with CUDA GPU,
/// allows to limit memory usage of the plotting process, defaults to number of CUDA GPUs found
/// + 1 to download future sector ahead of time.
///
/// Increase will result in higher memory usage.
#[arg(long)]
cuda_sector_downloading_concurrency: Option<NonZeroUsize>,
/// Specify exact GPUs to be used for plotting instead of using all GPUs (default behavior).
///
/// GPUs are coma-separated: `--cuda-gpus 0,1,3`. Empty string can be specified to disable CUDA
/// GPU usage.
#[arg(long)]
cuda_gpus: Option<String>,
}

/// Arguments for plotter
#[derive(Debug, Parser)]
pub(super) struct PlotterArgs {
Expand All @@ -85,6 +112,10 @@ pub(super) struct PlotterArgs {
/// Plotting options only used by CPU plotter
#[clap(flatten)]
cpu_plotting_options: CpuPlottingOptions,
/// Plotting options only used by CUDA GPU plotter
#[cfg(feature = "cuda")]
#[clap(flatten)]
cuda_plotting_options: CudaPlottingOptions,
/// Additional cluster components
#[clap(raw = true)]
pub(super) additional_components: Vec<String>,
Expand All @@ -102,6 +133,8 @@ where
let PlotterArgs {
piece_getter_concurrency,
cpu_plotting_options,
#[cfg(feature = "cuda")]
cuda_plotting_options,
additional_components: _,
} = plotter_args;

Expand All @@ -115,37 +148,63 @@ where

let global_mutex = Arc::default();

let (legacy_cpu_plotter, modern_cpu_plotter) = init_cpu_plotters::<_, PosTableLegacy, PosTable>(
cpu_plotting_options,
piece_getter,
global_mutex,
kzg,
erasure_coding,
registry,
)?;
let legacy_cpu_plotter = Arc::new(legacy_cpu_plotter);
let modern_cpu_plotter = Arc::new(modern_cpu_plotter);
let mut legacy_plotters = Vec::<Box<dyn Plotter + Send + Sync>>::new();
let mut modern_plotters = Vec::<Box<dyn Plotter + Send + Sync>>::new();

{
let maybe_cpu_plotters = init_cpu_plotters::<_, PosTableLegacy, PosTable>(
cpu_plotting_options,
piece_getter.clone(),
Arc::clone(&global_mutex),
kzg.clone(),
erasure_coding.clone(),
registry,
)?;

if let Some((legacy_cpu_plotter, modern_cpu_plotter)) = maybe_cpu_plotters {
legacy_plotters.push(Box::new(legacy_cpu_plotter));
modern_plotters.push(Box::new(modern_cpu_plotter));
}
}
#[cfg(feature = "cuda")]
{
let maybe_cuda_plotter = init_cuda_plotter(
cuda_plotting_options,
piece_getter,
global_mutex,
kzg,
erasure_coding,
registry,
)?;

if let Some(cuda_plotter) = maybe_cuda_plotter {
modern_plotters.push(Box::new(cuda_plotter));
}
}
let legacy_plotter = Arc::new(PoolPlotter::new(legacy_plotters, PLOTTING_RETRY_INTERVAL));
let modern_plotter = Arc::new(PoolPlotter::new(modern_plotters, PLOTTING_RETRY_INTERVAL));

Ok(Box::pin(async move {
select! {
result = plotter_service(&nats_client, &legacy_cpu_plotter, false).fuse() => {
result = plotter_service(&nats_client, &legacy_plotter, false).fuse() => {
result.map_err(|error| anyhow!("Plotter service failed: {error}"))
}
result = plotter_service(&nats_client, &modern_cpu_plotter, true).fuse() => {
result = plotter_service(&nats_client, &modern_plotter, true).fuse() => {
result.map_err(|error| anyhow!("Plotter service failed: {error}"))
}
}
}))
}

#[allow(clippy::type_complexity)]
fn init_cpu_plotters<PG, PosTableLegacy, PosTable>(
cpu_plotting_options: CpuPlottingOptions,
piece_getter: PG,
global_mutex: Arc<AsyncMutex<()>>,
kzg: Kzg,
erasure_coding: ErasureCoding,
registry: &mut Registry,
) -> anyhow::Result<(CpuPlotter<PG, PosTableLegacy>, CpuPlotter<PG, PosTable>)>
) -> anyhow::Result<Option<(CpuPlotter<PG, PosTableLegacy>, CpuPlotter<PG, PosTable>)>>
where
PG: PieceGetter + Clone + Send + Sync + 'static,
PosTableLegacy: Table,
Expand All @@ -160,6 +219,19 @@ where
cpu_plotting_thread_priority,
} = cpu_plotting_options;

let cpu_sector_encoding_concurrency =
if let Some(cpu_sector_encoding_concurrency) = cpu_sector_encoding_concurrency {
match NonZeroUsize::new(cpu_sector_encoding_concurrency) {
Some(cpu_sector_encoding_concurrency) => Some(cpu_sector_encoding_concurrency),
None => {
info!("CPU plotting was explicitly disabled");
return Ok(None);
}
}
} else {
None
};

let plotting_thread_pool_core_indices;
if let Some(cpu_plotting_cores) = cpu_plotting_cores {
plotting_thread_pool_core_indices = parse_cpu_cores_sets(&cpu_plotting_cores)
Expand Down Expand Up @@ -235,5 +307,67 @@ where
Some(registry),
);

Ok((legacy_cpu_plotter, modern_cpu_plotter))
Ok(Some((legacy_cpu_plotter, modern_cpu_plotter)))
}

#[cfg(feature = "cuda")]
fn init_cuda_plotter<PG>(
cuda_plotting_options: CudaPlottingOptions,
piece_getter: PG,
global_mutex: Arc<AsyncMutex<()>>,
kzg: Kzg,
erasure_coding: ErasureCoding,
registry: &mut Registry,
) -> anyhow::Result<Option<GpuPlotter<PG, CudaRecordsEncoder>>>
where
PG: PieceGetter + Clone + Send + Sync + 'static,
{
use std::collections::HashSet;
use subspace_proof_of_space_gpu::cuda::cuda_devices;

let CudaPlottingOptions {
cuda_sector_downloading_concurrency,
cuda_gpus,
} = cuda_plotting_options;

let mut cuda_devices = cuda_devices();

if let Some(cuda_gpus) = cuda_gpus {
if cuda_gpus.is_empty() {
return Ok(None);
}

let cuda_gpus = cuda_gpus
.split(',')
.map(|gpu_index| gpu_index.parse())
.collect::<Result<HashSet<usize>, _>>()?;

cuda_devices = cuda_devices
.into_iter()
.enumerate()
.filter_map(|(index, cuda_device)| cuda_gpus.contains(&index).then_some(cuda_device))
.collect();
}

let cuda_downloading_semaphore = Arc::new(Semaphore::new(
cuda_sector_downloading_concurrency
.map(|cuda_sector_downloading_concurrency| cuda_sector_downloading_concurrency.get())
.unwrap_or(cuda_devices.len() + 1),
));

Ok(Some(
GpuPlotter::new(
piece_getter,
cuda_downloading_semaphore,
cuda_devices
.into_iter()
.map(|cuda_device| CudaRecordsEncoder::new(cuda_device, Arc::clone(&global_mutex)))
.collect(),
global_mutex,
kzg,
erasure_coding,
Some(registry),
)
.map_err(|error| anyhow::anyhow!("Failed to initialize CUDA plotter: {error}"))?,
))
}
Loading

0 comments on commit 54943e6

Please sign in to comment.