Cherrypick allocr related changes from public (#247)
* No alloc (#250)

* don't pre-allocate kv cache (it needs reordering anyway)

* enable support for more int operations

* fix buffers allocation

* add kv_cache_ctx for enc_dec attn cache

* add lifespan

* use allocr in generate_sequence

* test all layers with allocr

* avoid copy of wav file

* force allocation of kv_cache otherwise buffers are reused

* get_rows for ints

* ggml: pimp up dot graph

* Revert "add lifespan"

This reverts commit 73cf7963ff9a6dcb37b7713910ba81b797ffb743.

* cleanup

* Revert "ggml: pimp up dot graph"

This reverts commit 6bc467133900e9ba8f5cf48710c9249ea7be8aaf.

* less restrictive test

* allocr for encoder

---------

Co-authored-by: Guillaume Wenzek <[email protected]>
cndn and gwenzek authored Dec 7, 2023
1 parent e1c3ea4 commit 42365df
Showing 8 changed files with 339 additions and 202 deletions.
4 changes: 2 additions & 2 deletions ggml/Makefile
@@ -1,6 +1,6 @@
-build: build/src/libggml.so ggml/build/bin/unity
+build: build/examples/unity/libfairseq2_cpp.so ggml/build/bin/unity

-build/src/libggml.so: Makefile examples/unity/*.h examples/unity/*.cpp src/ggml*.c
+build/examples/unity/libfairseq2_cpp.so: Makefile examples/unity/*.h examples/unity/*.cpp src/ggml*.c
mkdir -p build
cd build; cmake\
-DGGML_OPENBLAS=ON \
276 changes: 163 additions & 113 deletions ggml/examples/unity/fairseq2.cpp

Large diffs are not rendered by default.

22 changes: 16 additions & 6 deletions ggml/examples/unity/fairseq2.h
@@ -6,6 +6,14 @@
#include "ggml.h"
#include "kaldi-native-fbank/csrc/feature-fbank.h"

#include "ggml-alloc.h"

+#define FORCE_ALLOC(name, ctx, ggml_new_tensor)\
+bool name ## _save_no_alloc_ = ggml_get_no_alloc(ctx); \
+ggml_set_no_alloc(ctx, false); \
+ggml_tensor* name = ggml_new_tensor; \
+ggml_set_no_alloc(ctx, name ## _save_no_alloc_);

typedef int32_t llama_token;

extern "C" enum llama_token_type {
@@ -77,26 +85,28 @@ struct KeyValueTensor {

struct fairseq2_model {
// Context containing all tensors memory
-ggml_context* tensors_ctx;
+ggml_context* tensors_ctx = nullptr;

// Named tensors, all tensors should belong to tensors_ctx
-std::unordered_map<std::string, struct ggml_tensor *> tensors;
+std::unordered_map<std::string, struct ggml_tensor *> tensors = {};

// Hashmap containing model hyper-parameters.
-std::unordered_map<std::string, std::int64_t> hparams;
+std::unordered_map<std::string, std::int64_t> hparams = {};

// Hashmap containing layers hyper-parameters.
// Normally those can be inferred from hparams, but it avoids doing this logic in GGML
-std::unordered_map<std::string, std::int64_t> layer_config;
+std::unordered_map<std::string, std::int64_t> layer_config = {};

llama_vocab vocab;

// KV cache for attention layers
-mutable std::unordered_map<std::string, KeyValueTensor> kv_cache;
+mutable std::unordered_map<std::string, KeyValueTensor> kv_cache = {};

// an inference context, not managed by this object
// TODO: is this the best place to store this or should we also pass this to all forward methods ?
-ggml_context* ctx;
+ggml_context* ctx = nullptr;

+ggml_context* kv_cache_ctx = nullptr;
};

double fairseq2_model_layer_config_double(const fairseq2_model& model, std::string name);
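For context, a minimal sketch of what FORCE_ALLOC expands to and why it exists: with the graph allocator, the model context is normally kept in no_alloc mode (tensors are metadata only), so a tensor whose data must be written immediately — like the decoder prefix sequence in unity.cpp — needs the no_alloc flag flipped off around its creation. This is an illustrative expansion, not code from the commit; eos_idx and tgt_lang_idx stand in for the values set in unity.cpp.

    // FORCE_ALLOC(prefix_seq, ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2)) expands roughly to:
    bool prefix_seq_save_no_alloc_ = ggml_get_no_alloc(ctx);  // remember the current allocation mode
    ggml_set_no_alloc(ctx, false);                            // temporarily allocate real data buffers
    ggml_tensor* prefix_seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
    ggml_set_no_alloc(ctx, prefix_seq_save_no_alloc_);        // restore the previous mode

    ((int32_t*)prefix_seq->data)[0] = eos_idx;       // safe to write: data points at real memory
    ((int32_t*)prefix_seq->data)[1] = tgt_lang_idx;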
29 changes: 17 additions & 12 deletions ggml/examples/unity/unity.cpp
@@ -17,6 +17,7 @@
#include <iostream>
#include <sndfile.h>
#include <cstdlib>
#include "ggml-alloc.h"

struct unity_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@@ -111,7 +112,7 @@ Hypothesis* unity_decode(
/*eos_idx*/model.vocab.token_to_id["</s>"],
/*num_threads*/n_threads,
};
-struct ggml_tensor * prefix_seq = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 2);
+FORCE_ALLOC(prefix_seq, model.ctx, ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 2));
((int *)prefix_seq->data)[0] = job.eos_idx;
((int *)prefix_seq->data)[1] = tgt_lang_idx;
job.prefix_seq = prefix_seq;
@@ -133,13 +134,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
-int ctx_size_gb = 20;
-if (model.hparams["w2v2_encoder_config__num_encoder_layers"] == 24) {
-ctx_size_gb = 40;
-}

+// The ctx_size_mb mostly depends of input length and model dim.
+int ctx_size_mb = 128;
+auto encoder_buf = std::vector<uint8_t>(128 * 1024 * 1024);
+auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
+ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
char result_str[4096];
-static std::vector<uint8_t> encoder_buf(ctx_size_gb * 1024LL * 1024LL * 1024LL);

std::string input;
bool interactive = params.files.size() == 0;
@@ -178,17 +179,21 @@
}
int tgt_lang_idx = tgt_lang_ptr->second;

-// Load audio input
-std::vector<float> data(info.frames * info.channels); // Assume info.channels is always 1
-sf_readf_float(sndfile, data.data(), info.frames);

// Reset the ggml_context
model.ctx = ctx_from_buffer(encoder_buf);
-ggml_tensor* seqs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, info.frames, 1);
-memcpy(seqs->data, data.data(), data.size() * sizeof(float));
+ggml_set_no_alloc(model.ctx, false);
+ggml_tensor* seqs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, info.frames, info.channels);
+ggml_set_no_alloc(model.ctx, true);

+// Load audio input
+sf_readf_float(sndfile, (float*)seqs->data, info.frames);

// Audio encoder
ggml_cgraph* gf = unity_speech_encoder(model, seqs);
+ggml_allocr_alloc_graph(fwd_alloc, gf);
ggml_graph_compute_with_ctx(model.ctx, gf, params.n_threads);
+// encoder_output is valid until we call `ggml_allocr_reset(fwd_alloc)`
ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];

// Beam search decoding
@@ -201,7 +206,7 @@
int n = fairseq2_spm_detokenize(&model, tokens, (char*)&result_str);
std::cout << std::string((char*)&result_str, n) << std::endl;
ggml_free(model.ctx);

+ggml_allocr_reset(fwd_alloc);
}

return 0;
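For context, a self-contained sketch of the allocator lifecycle the main loop now follows: describe the graph against a no_alloc context, let ggml_allocr place every intermediate tensor inside one fixed buffer, compute, read the output, then reset the allocator so the next input reuses the same memory. The buffer sizes and the toy a*b graph below are illustrative only, not values from the commit.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include <cstdio>
    #include <vector>

    int main() {
        // Metadata-only context: tensor structs live here, tensor data goes in fwd_buf.
        std::vector<uint8_t> meta_buf(16 * 1024 * 1024);
        ggml_init_params p = { meta_buf.size(), meta_buf.data(), /*no_alloc*/ true };
        ggml_context* ctx = ggml_init(p);

        // Fixed forward-pass buffer managed by the graph allocator.
        std::vector<uint8_t> fwd_buf(16 * 1024 * 1024);
        ggml_allocr* alloc = ggml_allocr_new(fwd_buf.data(), fwd_buf.size(), /*alignment*/ 8);

        // Inputs are registered with the allocator so they get real storage to write into.
        ggml_tensor* a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        ggml_allocr_alloc(alloc, a);
        ggml_allocr_alloc(alloc, b);
        for (int i = 0; i < 4; ++i) { ((float*)a->data)[i] = (float)i; ((float*)b->data)[i] = 2.0f; }

        // Build the graph, place every node inside fwd_buf, then compute.
        ggml_tensor* c = ggml_mul(ctx, a, b);
        ggml_cgraph gf = ggml_build_forward(c);
        ggml_allocr_alloc_graph(alloc, &gf);
        ggml_set_no_alloc(ctx, false);                 // the compute work buffer comes from meta_buf
        ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads*/ 1);
        printf("c[3] = %f\n", ((float*)c->data)[3]);   // 6.0; valid until the allocator is reset

        ggml_allocr_reset(alloc);                      // the next input reuses the same fwd_buf
        ggml_allocr_free(alloc);
        ggml_free(ctx);
        return 0;
    }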
33 changes: 28 additions & 5 deletions ggml/ggml.py
@@ -13,6 +13,8 @@

import numpy as np
import torch
+import subprocess
+import sys

from ctypes_utils import NULLPTR, Ptr, c_fn, c_struct
from third_party_ggml import *
@@ -397,10 +399,21 @@ def forward(


def build_and_compute(
-ctx: ggml_context_p, tensor: ggml_tensor_p, num_threads: int = 1
-) -> None:
+ctx: ggml_context_p, tensor: ggml_tensor_p, num_threads: int = 1, dump: Union[bool, str] = False
+) -> ggml_cgraph:
gf = ggml_build_forward(tensor)
+need_alloc = tensor.contents.data == NULLPTR
+if need_alloc:
+alloc = FixedSizeArena(1024 * 1024 * 1024 * 2)
+ggml_allocr_alloc_graph(alloc.ptr, ctypes.pointer(gf))
+setattr(tensor, "__data", alloc)
+if dump:
+if dump == True:
+dump = f"dot/{sys._getframe(1).f_code.co_name}"
+ggml_graph_dump_dot(ctypes.pointer(gf), NULLPTR, dump.encode("ascii"))
+# subprocess.run(["dot", "-Tsvg", "-O", dump])
ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), num_threads)
+return gf


@c_fn(lib)
@@ -495,7 +508,7 @@ def fairseq2_model_layer_config_int(model: ctypes.c_void_p, name: bytes) -> int:

@c_fn(lib.fairseq2_kv_cache_alloc)
def _fairseq2_kv_cache_alloc(
-model: ctypes.c_void_p, beam_size: int, max_seq_len: int
+model: ctypes.c_void_p, ctx: ctypes.c_void_p, beam_size: int, max_seq_len: int
) -> None:
pass

@@ -507,13 +520,23 @@ def _fairseq2_kv_cache_reset(model: ctypes.c_void_p) -> None:

@contextlib.contextmanager
def fairseq2_kv_cache_alloc(
-model: ctypes.c_void_p, beam_size: int, max_seq_len: int
+model: ctypes.c_void_p, kv_cache_size: int, beam_size: int, max_seq_len: int
) -> Iterator[None]:
-_fairseq2_kv_cache_alloc(model, beam_size, max_seq_len)

+memory = torch.zeros(kv_cache_size, dtype=torch.uint8)
+ctx = ggml_init(
+params=ggml_init_params(
+mem_size=kv_cache_size,
+mem_buffer=ctypes.c_void_p(memory.data_ptr()),
+no_alloc=False,
+)
+)
+_fairseq2_kv_cache_alloc(model, ctx, beam_size, max_seq_len)
try:
yield
finally:
_fairseq2_kv_cache_reset(model)
+ggml_free(ctx)


@c_fn(lib)
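For context, the Python-side change mirrors what the C++ side does with kv_cache_ctx: the KV cache gets its own ggml_context backed by a dedicated buffer, so cached keys and values keep their storage across decoding steps instead of being recycled by the per-step graph allocator, and the whole cache can be dropped at once. A rough C++ equivalent follows; the buffer size, tensor shapes, and helper names are illustrative, not taken from the commit.

    #include "ggml.h"
    #include <vector>

    // Illustrative shapes; the real model derives these from its hyper-parameters.
    static const int64_t head_dim = 64, max_seq_len = 256, beam_size = 5;

    ggml_context* make_kv_cache_ctx(std::vector<uint8_t>& kv_buf) {
        ggml_init_params params = {
            /*mem_size  =*/ kv_buf.size(),
            /*mem_buffer=*/ kv_buf.data(),
            /*no_alloc  =*/ false,   // tensors created here get real, persistent storage
        };
        return ggml_init(params);
    }

    void example_kv_cache() {
        std::vector<uint8_t> kv_buf(64 * 1024 * 1024);
        ggml_context* kv_cache_ctx = make_kv_cache_ctx(kv_buf);

        // Persistent K/V tensors for one attention layer; they are written during decoding
        // and must outlive every per-step forward graph.
        ggml_tensor* k_cache = ggml_new_tensor_3d(kv_cache_ctx, GGML_TYPE_F32, head_dim, max_seq_len, beam_size);
        ggml_tensor* v_cache = ggml_new_tensor_3d(kv_cache_ctx, GGML_TYPE_F32, head_dim, max_seq_len, beam_size);
        (void)k_cache; (void)v_cache;

        // ... run decoding; per-step intermediates live in the forward-pass allocator ...

        ggml_free(kv_cache_ctx);  // drop the whole cache at once when decoding ends
    }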
4 changes: 2 additions & 2 deletions ggml/include/ggml/ggml.h
@@ -214,7 +214,7 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this

#define GGML_MAX_DIMS 4
-#define GGML_MAX_NODES 8192
+#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 256
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
@@ -530,7 +530,7 @@ extern "C" {
// next prime after GGML_MAX_NODES
// #define GGML_GRAPH_HASHTABLE_SIZE 4099
// next prime after GGML_MAX_NODES * 2 (nodes + leafs)
-#define GGML_GRAPH_HASHTABLE_SIZE 16411
+#define GGML_GRAPH_HASHTABLE_SIZE 8273

// computation graph
struct ggml_cgraph {
10 changes: 7 additions & 3 deletions ggml/src/ggml.c
@@ -6822,9 +6822,7 @@ struct ggml_tensor * ggml_get_rows(
is_node = true;
}

-// TODO: implement non F32 return
-//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);

result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -8982,10 +8980,12 @@ static void ggml_compute_forward_dup(
}
switch (src0->type) {
case GGML_TYPE_F16:
+case GGML_TYPE_I16:
{
ggml_compute_forward_dup_f16(params, src0, dst);
} break;
case GGML_TYPE_F32:
+case GGML_TYPE_I32:
{
ggml_compute_forward_dup_f32(params, src0, dst);
} break;
@@ -10379,6 +10379,7 @@ static void ggml_compute_forward_repeat(
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
+case GGML_TYPE_I32:
{
ggml_compute_forward_repeat_f32(params, src0, dst);
} break;
@@ -10520,6 +10521,7 @@ static void ggml_compute_forward_concat(
struct ggml_tensor* dst) {
switch (src0->type) {
case GGML_TYPE_F32:
+case GGML_TYPE_I32:
{
ggml_compute_forward_concat_f32(params, src0, src1, dst);
} break;
@@ -12284,10 +12286,12 @@ static void ggml_compute_forward_get_rows(
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
+case GGML_TYPE_I16:
{
ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
+case GGML_TYPE_I32:
{
ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
} break;
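For context, the ggml.c changes above make ggml_get_rows preserve the source tensor's type instead of always returning F32, and let dup/repeat/concat/get_rows dispatch on I16/I32 inputs, so integer tensors such as token ids can be gathered and copied inside a graph. A hedged sketch of why that matters, assuming an existing ggml_context* ctx and illustrative seq_len/beam_size values:

    // Reordering rows of an I32 tensor of token ids, as beam search does when it reorders
    // hypotheses. With the change above the result stays GGML_TYPE_I32; previously
    // ggml_get_rows would have produced a GGML_TYPE_F32 tensor.
    ggml_tensor* token_ids  = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, seq_len, beam_size);  // one row per beam
    ggml_tensor* beam_order = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);           // row indices
    ggml_tensor* reordered  = ggml_get_rows(ctx, token_ids, beam_order);
    // reordered->type == GGML_TYPE_I32, shape [seq_len, beam_size]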
