Skip to content

Commit

Permalink
Initial GPU plotter implementation (not integrated into the farmer yet)
Browse files Browse the repository at this point in the history
  • Loading branch information
nazar-pc committed Sep 4, 2024
1 parent 92383b3 commit 24da28b
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 143 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions crates/subspace-farmer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ subspace-core-primitives = { version = "0.1.0", path = "../subspace-core-primiti
subspace-metrics = { version = "0.1.0", path = "../../shared/subspace-metrics", optional = true }
subspace-networking = { version = "0.1.0", path = "../subspace-networking" }
subspace-proof-of-space = { version = "0.1.0", path = "../subspace-proof-of-space" }
subspace-proof-of-space-gpu = { version = "0.1.0", path = "../../shared/subspace-proof-of-space-gpu", optional = true }
subspace-rpc-primitives = { version = "0.1.0", path = "../subspace-rpc-primitives" }
substrate-bip39 = "0.6.0"
supports-color = { version = "3.0.0", optional = true }
Expand All @@ -75,6 +76,10 @@ zeroize = "1.8.1"
default = ["default-library", "binary"]
cluster = ["dep:async-nats"]
numa = ["dep:hwlocality"]
# Only Volta+ architectures are supported (GeForce RTX 20xx consumer GPUs and newer)
cuda = ["_gpu", "subspace-proof-of-space-gpu/cuda"]
# Internal feature, shouldn't be used directly
_gpu = []

# TODO: This is a hack for https://github.com/rust-lang/cargo/issues/1982, `default-library` is what would essentially
# be default, but because binary compilation will require explicit feature to be specified without `binary` feature we
Expand Down
2 changes: 2 additions & 0 deletions crates/subspace-farmer/src/plotter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//! implementations without the rest of the library being aware of implementation details.
pub mod cpu;
#[cfg(feature = "_gpu")]
pub mod gpu;
pub mod pool;

use async_trait::async_trait;
Expand Down
1 change: 1 addition & 0 deletions crates/subspace-farmer/src/plotter/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ where
PosTable: Table,
{
async fn has_free_capacity(&self) -> Result<bool, String> {
// TODO: Check available thread pools
Ok(self.downloading_semaphore.available_permits() > 0)
}

Expand Down
137 changes: 63 additions & 74 deletions crates/subspace-farmer/src/plotter/gpu.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
//! CPU plotter
//! GPU plotter
#[cfg(feature = "cuda")]
pub mod cuda;
mod gpu_encoders_manager;
pub mod metrics;

use crate::plotter::cpu::metrics::CpuPlotterMetrics;
use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
use crate::plotter::gpu::metrics::GpuPlotterMetrics;
use crate::plotter::{Plotter, SectorPlottingProgress};
use crate::thread_pool_manager::PlottingThreadPoolManager;
use crate::utils::AsyncJoinOnDrop;
use async_lock::Mutex as AsyncMutex;
use async_trait::async_trait;
Expand All @@ -13,12 +16,10 @@ use futures::channel::mpsc;
use futures::stream::FuturesUnordered;
use futures::{select, stream, FutureExt, Sink, SinkExt, StreamExt};
use prometheus_client::registry::Registry;
use std::any::type_name;
use std::error::Error;
use std::fmt;
use std::future::pending;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::num::TryFromIntError;
use std::pin::pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
Expand All @@ -27,11 +28,10 @@ use subspace_core_primitives::crypto::kzg::Kzg;
use subspace_core_primitives::{PublicKey, SectorIndex};
use subspace_erasure_coding::ErasureCoding;
use subspace_farmer_components::plotting::{
download_sector, encode_sector, CpuRecordsEncoder, DownloadSectorOptions, EncodeSectorOptions,
PlottingError,
download_sector, encode_sector, DownloadSectorOptions, EncodeSectorOptions, PlottingError,
RecordsEncoder,
};
use subspace_farmer_components::{FarmerProtocolInfo, PieceGetter};
use subspace_proof_of_space::Table;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::task::yield_now;
use tracing::{warn, Instrument};
Expand All @@ -45,31 +45,39 @@ struct Handlers {
plotting_progress: Handler3<PublicKey, SectorIndex, SectorPlottingProgress>,
}

/// CPU plotter
pub struct CpuPlotter<PG, PosTable> {
/// GPU-specific [`RecordsEncoder`] with extra APIs
pub trait GpuRecordsEncoder: RecordsEncoder + fmt::Debug + Send {
/// GPU encoder type, typically related to GPU vendor
const TYPE: &'static str;
}

/// GPU plotter
pub struct GpuPlotter<PG, GRE> {
piece_getter: PG,
downloading_semaphore: Arc<Semaphore>,
plotting_thread_pool_manager: PlottingThreadPoolManager,
record_encoding_concurrency: NonZeroUsize,
gpu_records_encoders_manager: GpuRecordsEncoderManager<GRE>,
global_mutex: Arc<AsyncMutex<()>>,
kzg: Kzg,
erasure_coding: ErasureCoding,
handlers: Arc<Handlers>,
tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
_background_tasks: AsyncJoinOnDrop<()>,
abort_early: Arc<AtomicBool>,
metrics: Option<Arc<CpuPlotterMetrics>>,
_phantom: PhantomData<PosTable>,
metrics: Option<Arc<GpuPlotterMetrics>>,
}

impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
impl<PG, GRE> fmt::Debug for GpuPlotter<PG, GRE>
where
GRE: GpuRecordsEncoder + 'static,
{
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CpuPlotter").finish_non_exhaustive()
f.debug_struct(&format!("GpuPlotter[type = {}]", GRE::TYPE))
.finish_non_exhaustive()
}
}

impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
impl<PG, RE> Drop for GpuPlotter<PG, RE> {
#[inline]
fn drop(&mut self) {
self.abort_early.store(true, Ordering::Release);
Expand All @@ -78,12 +86,13 @@ impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
}

#[async_trait]
impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
impl<PG, GRE> Plotter for GpuPlotter<PG, GRE>
where
PG: PieceGetter + Clone + Send + Sync + 'static,
PosTable: Table,
GRE: GpuRecordsEncoder + 'static,
{
async fn has_free_capacity(&self) -> Result<bool, String> {
// TODO: Check available GPU encoders
Ok(self.downloading_semaphore.available_permits() > 0)
}

Expand All @@ -93,7 +102,7 @@ where
sector_index: SectorIndex,
farmer_protocol_info: FarmerProtocolInfo,
pieces_in_sector: u16,
replotting: bool,
_replotting: bool,
mut progress_sender: mpsc::Sender<SectorPlottingProgress>,
) {
let start = Instant::now();
Expand Down Expand Up @@ -135,7 +144,6 @@ where
sector_index,
farmer_protocol_info,
pieces_in_sector,
replotting,
progress_sender,
)
.await
Expand All @@ -147,7 +155,7 @@ where
sector_index: SectorIndex,
farmer_protocol_info: FarmerProtocolInfo,
pieces_in_sector: u16,
replotting: bool,
_replotting: bool,
progress_sender: mpsc::Sender<SectorPlottingProgress>,
) -> bool {
let start = Instant::now();
Expand All @@ -164,7 +172,6 @@ where
sector_index,
farmer_protocol_info,
pieces_in_sector,
replotting,
progress_sender,
)
.await;
Expand All @@ -173,23 +180,23 @@ where
}
}

impl<PG, PosTable> CpuPlotter<PG, PosTable>
impl<PG, GRE> GpuPlotter<PG, GRE>
where
PG: PieceGetter + Clone + Send + Sync + 'static,
PosTable: Table,
GRE: GpuRecordsEncoder + 'static,
{
/// Create new instance
#[allow(clippy::too_many_arguments)]
/// Create new instance.
///
/// Returns an error if empty list of encoders is provided.
pub fn new(
piece_getter: PG,
downloading_semaphore: Arc<Semaphore>,
plotting_thread_pool_manager: PlottingThreadPoolManager,
record_encoding_concurrency: NonZeroUsize,
gpu_records_encoders: Vec<GRE>,
global_mutex: Arc<AsyncMutex<()>>,
kzg: Kzg,
erasure_coding: ErasureCoding,
registry: Option<&mut Registry>,
) -> Self {
) -> Result<Self, TryFromIntError> {
let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);

// Basically runs plotting tasks in the background and allows to abort on drop
Expand Down Expand Up @@ -219,19 +226,19 @@ where
);

let abort_early = Arc::new(AtomicBool::new(false));
let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
let metrics = registry.map(|registry| {
Arc::new(CpuPlotterMetrics::new(
Arc::new(GpuPlotterMetrics::new(
registry,
type_name::<PosTable>(),
plotting_thread_pool_manager.thread_pool_pairs(),
GRE::TYPE,
gpu_records_encoders_manager.gpu_records_encoders(),
))
});

Self {
Ok(Self {
piece_getter,
downloading_semaphore,
plotting_thread_pool_manager,
record_encoding_concurrency,
gpu_records_encoders_manager,
global_mutex,
kzg,
erasure_coding,
Expand All @@ -240,8 +247,7 @@ where
_background_tasks: background_tasks,
abort_early,
metrics,
_phantom: PhantomData,
}
})
}

/// Subscribe to plotting progress notifications
Expand All @@ -261,7 +267,6 @@ where
sector_index: SectorIndex,
farmer_protocol_info: FarmerProtocolInfo,
pieces_in_sector: u16,
replotting: bool,
mut progress_sender: PS,
) where
PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
Expand All @@ -280,8 +285,7 @@ where

let plotting_fut = {
let piece_getter = self.piece_getter.clone();
let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
let record_encoding_concurrency = self.record_encoding_concurrency;
let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
let global_mutex = Arc::clone(&self.global_mutex);
let kzg = self.kzg.clone();
let erasure_coding = self.erasure_coding.clone();
Expand Down Expand Up @@ -349,17 +353,11 @@ where

// Plotting
let (sector, plotted_sector) = {
let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
if let Some(metrics) = &metrics {
metrics.plotting_capacity_used.inc();
}

let thread_pool = if replotting {
&thread_pools.replotting
} else {
&thread_pools.plotting
};

// Give a chance to interrupt plotting if necessary
yield_now().await;

Expand All @@ -379,31 +377,22 @@ where
let encoding_start = Instant::now();

let plotting_result = tokio::task::block_in_place(|| {
thread_pool.install(|| {
let mut sector = Vec::new();
let mut generators = (0..record_encoding_concurrency.get())
.map(|_| PosTable::generator())
.collect::<Vec<_>>();
let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
&mut generators,
&erasure_coding,
&global_mutex,
);

let plotted_sector = encode_sector(
downloaded_sector,
EncodeSectorOptions {
sector_index,
sector_output: &mut sector,
records_encoder: &mut records_encoder,
abort_early: &abort_early,
},
)?;

Ok((sector, plotted_sector))
})
let mut sector = Vec::new();

let plotted_sector = encode_sector(
downloaded_sector,
EncodeSectorOptions {
sector_index,
sector_output: &mut sector,
records_encoder: &mut *records_encoder,
abort_early: &abort_early,
},
)?;

Ok((sector, plotted_sector))
});
drop(thread_pools);
drop(records_encoder);

if let Some(metrics) = &metrics {
metrics.plotting_capacity_used.dec();
}
Expand Down Expand Up @@ -476,7 +465,7 @@ struct ProgressUpdater {
public_key: PublicKey,
sector_index: SectorIndex,
handlers: Arc<Handlers>,
metrics: Option<Arc<CpuPlotterMetrics>>,
metrics: Option<Arc<GpuPlotterMetrics>>,
}

impl ProgressUpdater {
Expand Down
Loading

0 comments on commit 24da28b

Please sign in to comment.