diff --git a/go.mod b/go.mod index ac147994..5bf3f9e2 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,9 @@ require ( ) require ( + github.com/btcsuite/btcd/btcec/v2 v2.3.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect diff --git a/go.sum b/go.sum index d2432cb4..02c61fd2 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,12 @@ +github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurTXGPFfiQ= +github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/kilic/bls12-381 v0.1.0 h1:encrdjqKMEvabVQ7qYOKu1OvhqpK4s47wDYtNiPtlp4= diff --git a/internal/runtime/cache/cache.go b/internal/runtime/cache/cache.go new file mode 100644 index 00000000..f852ef57 --- /dev/null +++ b/internal/runtime/cache/cache.go @@ -0,0 +1,100 @@ +package cache + +import ( + "crypto/sha256" + "sync" + + "github.com/tetratelabs/wazero" +) + +// Cache manages compiled Wasm modules +type Cache struct { + mu sync.RWMutex + codeCache map[string][]byte // stores raw Wasm bytes + compiledModules map[string]wazero.CompiledModule + pinnedModules map[string]struct{} + moduleHits map[string]uint32 + moduleSizes map[string]uint64 +} + +// New creates a new cache instance +func New() *Cache { + return &Cache{ + codeCache: make(map[string][]byte), + compiledModules: make(map[string]wazero.CompiledModule), + pinnedModules: make(map[string]struct{}), + moduleHits: make(map[string]uint32), + moduleSizes: make(map[string]uint64), + } +} + +// Save stores a Wasm module in the cache +func (c *Cache) Save(code []byte) []byte { + checksum := sha256.Sum256(code) + key := string(checksum[:]) + + c.mu.Lock() + defer c.mu.Unlock() + + c.codeCache[key] = code + c.moduleSizes[key] = uint64(len(code)) + return checksum[:] +} + +// Load retrieves a Wasm module from the cache +func (c *Cache) Load(checksum []byte) ([]byte, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + code, exists := c.codeCache[string(checksum)] + if exists { + c.moduleHits[string(checksum)]++ + } + return code, exists +} + +// Pin marks a module as pinned in memory +func (c *Cache) Pin(checksum []byte) { + c.mu.Lock() + defer c.mu.Unlock() + c.pinnedModules[string(checksum)] = struct{}{} +} + +// Unpin removes the pin from a module +func (c *Cache) Unpin(checksum []byte) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.pinnedModules, string(checksum)) +} + +// Remove deletes a module from the cache if it's not pinned +func (c *Cache) Remove(checksum []byte) bool { + c.mu.Lock() + defer c.mu.Unlock() + + key := string(checksum) + if _, isPinned := c.pinnedModules[key]; isPinned { + return false + } + + delete(c.codeCache, key) + delete(c.compiledModules, key) + delete(c.moduleHits, key) + delete(c.moduleSizes, key) + return true +} + +// SaveCompiledModule stores a compiled module in the cache +func (c *Cache) SaveCompiledModule(checksum []byte, module wazero.CompiledModule) { + c.mu.Lock() + defer c.mu.Unlock() + c.compiledModules[string(checksum)] = module +} + +// LoadCompiledModule retrieves a compiled module from the cache +func (c *Cache) LoadCompiledModule(checksum []byte) (wazero.CompiledModule, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + module, exists := c.compiledModules[string(checksum)] + return module, exists +} diff --git a/internal/runtime/crypto.go b/internal/runtime/crypto.go deleted file mode 100644 index 359fa2cc..00000000 --- a/internal/runtime/crypto.go +++ /dev/null @@ -1,152 +0,0 @@ -package runtime - -import ( - "crypto/ecdh" - "crypto/ecdsa" - "crypto/elliptic" - "errors" - "fmt" - "math/big" - - bls12381 "github.com/kilic/bls12-381" -) - -// BLS12381AggregateG1 aggregates multiple G1 points into a single compressed G1 point. -func BLS12381AggregateG1(elements [][]byte) ([]byte, error) { - if len(elements) == 0 { - return nil, fmt.Errorf("no elements to aggregate") - } - - g1 := bls12381.NewG1() - result := g1.Zero() - - for _, element := range elements { - point, err := g1.FromCompressed(element) - if err != nil { - return nil, fmt.Errorf("failed to decompress G1 point: %w", err) - } - g1.Add(result, result, point) - } - - return g1.ToCompressed(result), nil -} - -// BLS12381AggregateG2 aggregates multiple G2 points into a single compressed G2 point. -func BLS12381AggregateG2(elements [][]byte) ([]byte, error) { - if len(elements) == 0 { - return nil, fmt.Errorf("no elements to aggregate") - } - - g2 := bls12381.NewG2() - result := g2.Zero() - - for _, element := range elements { - point, err := g2.FromCompressed(element) - if err != nil { - return nil, fmt.Errorf("failed to decompress G2 point: %w", err) - } - g2.Add(result, result, point) - } - - return g2.ToCompressed(result), nil -} - -// BLS12381HashToG1 hashes arbitrary bytes to a compressed G1 point. -func BLS12381HashToG1(message []byte) ([]byte, error) { - g1 := bls12381.NewG1() - // You can choose a domain separation string of your liking. - // Here, we use a placeholder domain: "BLS12381G1_XMD:SHA-256_SSWU_RO_" - point, err := g1.HashToCurve(message, []byte("BLS12381G1_XMD:SHA-256_SSWU_RO_")) - if err != nil { - return nil, fmt.Errorf("failed to hash to G1: %w", err) - } - return g1.ToCompressed(point), nil -} - -// BLS12381HashToG2 hashes arbitrary bytes to a compressed G2 point. -func BLS12381HashToG2(message []byte) ([]byte, error) { - g2 := bls12381.NewG2() - // Similar domain separation string for G2. - point, err := g2.HashToCurve(message, []byte("BLS12381G2_XMD:SHA-256_SSWU_RO_")) - if err != nil { - return nil, fmt.Errorf("failed to hash to G2: %w", err) - } - return g2.ToCompressed(point), nil -} - -// BLS12381PairingEquality checks if e(a1, a2) == e(b1, b2) in the BLS12-381 pairing. -func BLS12381PairingEquality(a1Compressed, a2Compressed, b1Compressed, b2Compressed []byte) (bool, error) { - g1 := bls12381.NewG1() - g2 := bls12381.NewG2() - - a1, err := g1.FromCompressed(a1Compressed) - if err != nil { - return false, fmt.Errorf("failed to decompress a1: %w", err) - } - a2, err := g2.FromCompressed(a2Compressed) - if err != nil { - return false, fmt.Errorf("failed to decompress a2: %w", err) - } - b1, err := g1.FromCompressed(b1Compressed) - if err != nil { - return false, fmt.Errorf("failed to decompress b1: %w", err) - } - b2, err := g2.FromCompressed(b2Compressed) - if err != nil { - return false, fmt.Errorf("failed to decompress b2: %w", err) - } - - engine := bls12381.NewEngine() - // AddPair computes pairing e(a1, a2). - engine.AddPair(a1, a2) - // AddPairInv computes pairing e(b1, b2)^(-1), so effectively we check e(a1,a2) * e(b1,b2)^(-1) == 1. - engine.AddPairInv(b1, b2) - - ok := engine.Check() - return ok, nil -} - -// Secp256r1Verify verifies a P-256 ECDSA signature. -// hash is the message digest (NOT the preimage), -// signature should be 64 bytes (r and s concatenated), -// pubkey should be an uncompressed or compressed public key in standard format. -func Secp256r1Verify(hash, signature, pubkey []byte) (bool, error) { - // Parse public key using crypto/ecdh - curve := ecdh.P256() - key, err := curve.NewPublicKey(pubkey) - if err != nil { - return false, fmt.Errorf("invalid public key: %w", err) - } - - // Get the raw coordinates for ECDSA verification - rawKey := key.Bytes() - x, y := elliptic.UnmarshalCompressed(elliptic.P256(), rawKey) - if x == nil { - return false, errors.New("failed to parse public key coordinates") - } - - // Parse signature: must be exactly 64 bytes => r (first 32 bytes), s (second 32 bytes). - if len(signature) != 64 { - return false, fmt.Errorf("signature must be 64 bytes, got %d", len(signature)) - } - r := new(big.Int).SetBytes(signature[:32]) - s := new(big.Int).SetBytes(signature[32:64]) - - pub := &ecdsa.PublicKey{ - Curve: elliptic.P256(), - X: x, - Y: y, - } - - verified := ecdsa.Verify(pub, hash, r, s) - return verified, nil -} - -// Secp256r1RecoverPubkey tries to recover a P-256 public key from a signature. -// In general, ECDSA on P-256 is not commonly used with "public key recovery" like secp256k1. -// This is non-standard and provided here as a placeholder or with specialized tooling only. -func Secp256r1RecoverPubkey(hash, signature []byte, recovery byte) ([]byte, error) { - // ECDSA on secp256r1 (P-256) does not support public key recovery in the standard library. - // Typically one would need a specialized library. This stub is included for completeness. - return nil, fmt.Errorf("public key recovery is not standard for secp256r1") -} diff --git a/internal/runtime/crypto/bls.go b/internal/runtime/crypto/bls.go new file mode 100644 index 00000000..e52e12f3 --- /dev/null +++ b/internal/runtime/crypto/bls.go @@ -0,0 +1,100 @@ +package crypto + +import ( + "fmt" + + bls12381 "github.com/kilic/bls12-381" +) + +// BLS12381AggregateG1 aggregates multiple G1 points into a single compressed G1 point. +func (v *Verifier) BLS12381AggregateG1(elements [][]byte) ([]byte, error) { + if len(elements) == 0 { + return nil, fmt.Errorf("no elements to aggregate") + } + + g1 := bls12381.NewG1() + result := g1.Zero() + + for _, element := range elements { + point, err := g1.FromCompressed(element) + if err != nil { + return nil, fmt.Errorf("failed to decompress G1 point: %w", err) + } + g1.Add(result, result, point) + } + + return g1.ToCompressed(result), nil +} + +// BLS12381AggregateG2 aggregates multiple G2 points into a single compressed G2 point. +func (v *Verifier) BLS12381AggregateG2(elements [][]byte) ([]byte, error) { + if len(elements) == 0 { + return nil, fmt.Errorf("no elements to aggregate") + } + + g2 := bls12381.NewG2() + result := g2.Zero() + + for _, element := range elements { + point, err := g2.FromCompressed(element) + if err != nil { + return nil, fmt.Errorf("failed to decompress G2 point: %w", err) + } + g2.Add(result, result, point) + } + + return g2.ToCompressed(result), nil +} + +// BLS12381HashToG1 hashes a message to a G1 point. +func (v *Verifier) BLS12381HashToG1(message []byte) ([]byte, error) { + g1 := bls12381.NewG1() + domain := []byte("BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_NUL_") + point, err := g1.HashToCurve(message, domain) + if err != nil { + return nil, fmt.Errorf("failed to hash to G1: %w", err) + } + return g1.ToCompressed(point), nil +} + +// BLS12381HashToG2 hashes a message to a G2 point. +func (v *Verifier) BLS12381HashToG2(message []byte) ([]byte, error) { + g2 := bls12381.NewG2() + domain := []byte("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_") + point, err := g2.HashToCurve(message, domain) + if err != nil { + return nil, fmt.Errorf("failed to hash to G2: %w", err) + } + return g2.ToCompressed(point), nil +} + +// BLS12381PairingEquality checks if e(a1, a2) = e(b1, b2). +func (v *Verifier) BLS12381PairingEquality(a1Compressed, a2Compressed, b1Compressed, b2Compressed []byte) (bool, error) { + engine := bls12381.NewEngine() + g1 := bls12381.NewG1() + g2 := bls12381.NewG2() + + a1, err := g1.FromCompressed(a1Compressed) + if err != nil { + return false, fmt.Errorf("failed to decompress a1: %w", err) + } + + a2, err := g2.FromCompressed(a2Compressed) + if err != nil { + return false, fmt.Errorf("failed to decompress a2: %w", err) + } + + b1, err := g1.FromCompressed(b1Compressed) + if err != nil { + return false, fmt.Errorf("failed to decompress b1: %w", err) + } + + b2, err := g2.FromCompressed(b2Compressed) + if err != nil { + return false, fmt.Errorf("failed to decompress b2: %w", err) + } + + engine.AddPair(a1, a2) + engine.AddPairInv(b1, b2) + return engine.Check(), nil +} diff --git a/internal/runtime/crypto/crypto.go b/internal/runtime/crypto/crypto.go new file mode 100644 index 00000000..41ef34e4 --- /dev/null +++ b/internal/runtime/crypto/crypto.go @@ -0,0 +1,90 @@ +package crypto + +import ( + "crypto/ed25519" + "crypto/sha256" + "crypto/sha512" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" +) + +// Verifier handles cryptographic verification operations +type Verifier struct{} + +// New creates a new crypto verifier +func New() *Verifier { + return &Verifier{} +} + +// VerifySecp256k1Signature verifies a secp256k1 signature +func (v *Verifier) VerifySecp256k1Signature(hash, signature, pubkey []byte) (bool, error) { + if len(hash) != sha256.Size { + return false, ErrInvalidHashFormat + } + if len(signature) != 64 { + return false, ErrInvalidSignatureFormat + } + if len(pubkey) != 33 && len(pubkey) != 65 { + return false, ErrInvalidPubkeyFormat + } + + // Parse public key + pk, err := btcec.ParsePubKey(pubkey) + if err != nil { + return false, err + } + + // Parse signature + r := new(btcec.ModNScalar) + s := new(btcec.ModNScalar) + if !r.SetByteSlice(signature[:32]) || !s.SetByteSlice(signature[32:]) { + return false, ErrInvalidSignatureFormat + } + sig := ecdsa.NewSignature(r, s) + + return sig.Verify(hash, pk), nil +} + +// VerifyEd25519Signature verifies an ed25519 signature +func (v *Verifier) VerifyEd25519Signature(message, signature, pubkey []byte) (bool, error) { + if len(signature) != ed25519.SignatureSize { + return false, ErrInvalidSignatureFormat + } + if len(pubkey) != ed25519.PublicKeySize { + return false, ErrInvalidPubkeyFormat + } + + return ed25519.Verify(ed25519.PublicKey(pubkey), message, signature), nil +} + +// VerifyEd25519Signatures verifies multiple ed25519 signatures in batch +func (v *Verifier) VerifyEd25519Signatures(messages [][]byte, signatures [][]byte, pubkeys [][]byte) (bool, error) { + if len(messages) != len(signatures) || len(signatures) != len(pubkeys) { + return false, ErrInvalidBatchFormat + } + + for i := range messages { + ok, err := v.VerifyEd25519Signature(messages[i], signatures[i], pubkeys[i]) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + } + + return true, nil +} + +// SHA256 computes the SHA256 hash of data +func (v *Verifier) SHA256(data []byte) []byte { + sum := sha256.Sum256(data) + return sum[:] +} + +// SHA512 computes the SHA512 hash of data +func (v *Verifier) SHA512(data []byte) []byte { + sum := sha512.Sum512(data) + return sum[:] +} diff --git a/internal/runtime/crypto/errors.go b/internal/runtime/crypto/errors.go new file mode 100644 index 00000000..dfac758b --- /dev/null +++ b/internal/runtime/crypto/errors.go @@ -0,0 +1,14 @@ +package crypto + +import "errors" + +var ( + // ErrInvalidHashFormat is returned when a hash has an invalid format + ErrInvalidHashFormat = errors.New("invalid hash format") + // ErrInvalidSignatureFormat is returned when a signature has an invalid format + ErrInvalidSignatureFormat = errors.New("invalid signature format") + // ErrInvalidPubkeyFormat is returned when a public key has an invalid format + ErrInvalidPubkeyFormat = errors.New("invalid public key format") + // ErrInvalidBatchFormat is returned when batch verification inputs have mismatched lengths + ErrInvalidBatchFormat = errors.New("invalid batch format: mismatched input lengths") +) diff --git a/internal/runtime/crypto/host.go b/internal/runtime/crypto/host.go new file mode 100644 index 00000000..836e49bb --- /dev/null +++ b/internal/runtime/crypto/host.go @@ -0,0 +1,149 @@ +package crypto + +import ( + "fmt" + + "github.com/CosmWasm/wasmvm/v2/internal/runtime/memory" +) + +// HostFunctions provides cryptographic host functions +type HostFunctions struct { + verifier *Verifier + mem *memory.HostFunctions + allocator *memory.Allocator +} + +// NewHostFunctions creates a new set of crypto host functions +func NewHostFunctions(mem *memory.HostFunctions, allocator *memory.Allocator) *HostFunctions { + return &HostFunctions{ + verifier: New(), + mem: mem, + allocator: allocator, + } +} + +// VerifySecp256k1 verifies a secp256k1 signature +func (h *HostFunctions) VerifySecp256k1(hashPtr, sigPtr, pubkeyPtr uint32) (uint32, error) { + // Read regions + hashRegion, err := h.mem.ReadRegion(hashPtr) + if err != nil { + return 0, err + } + + sigRegion, err := h.mem.ReadRegion(sigPtr) + if err != nil { + return 0, err + } + + pubkeyRegion, err := h.mem.ReadRegion(pubkeyPtr) + if err != nil { + return 0, err + } + + // Read data from memory + hash, err := h.mem.ReadString(hashRegion) + if err != nil { + return 0, err + } + + sig, err := h.mem.ReadString(sigRegion) + if err != nil { + return 0, err + } + + pubkey, err := h.mem.ReadString(pubkeyRegion) + if err != nil { + return 0, err + } + + // Verify signature + ok, err := h.verifier.VerifySecp256k1Signature([]byte(hash), []byte(sig), []byte(pubkey)) + if err != nil { + return 0, err + } + + if ok { + return 1, nil + } + return 0, nil +} + +// VerifyEd25519 verifies an ed25519 signature +func (h *HostFunctions) VerifyEd25519(msgPtr, sigPtr, pubkeyPtr uint32) (uint32, error) { + // Read regions + msgRegion, err := h.mem.ReadRegion(msgPtr) + if err != nil { + return 0, err + } + + sigRegion, err := h.mem.ReadRegion(sigPtr) + if err != nil { + return 0, err + } + + pubkeyRegion, err := h.mem.ReadRegion(pubkeyPtr) + if err != nil { + return 0, err + } + + // Read data from memory + msg, err := h.mem.ReadString(msgRegion) + if err != nil { + return 0, err + } + + sig, err := h.mem.ReadString(sigRegion) + if err != nil { + return 0, err + } + + pubkey, err := h.mem.ReadString(pubkeyRegion) + if err != nil { + return 0, err + } + + // Verify signature + ok, err := h.verifier.VerifyEd25519Signature([]byte(msg), []byte(sig), []byte(pubkey)) + if err != nil { + return 0, err + } + + if ok { + return 1, nil + } + return 0, nil +} + +// Hash computes various cryptographic hashes +func (h *HostFunctions) Hash(hashType uint32, msgPtr uint32) (uint32, error) { + // Read message region + msgRegion, err := h.mem.ReadRegion(msgPtr) + if err != nil { + return 0, err + } + + // Read message from memory + msg, err := h.mem.ReadString(msgRegion) + if err != nil { + return 0, err + } + + // Compute hash based on type + var hash []byte + switch hashType { + case 1: // SHA256 + hash = h.verifier.SHA256([]byte(msg)) + case 2: // SHA512 + hash = h.verifier.SHA512([]byte(msg)) + default: + return 0, fmt.Errorf("unsupported hash type: %d", hashType) + } + + // Write hash to memory and return pointer + ptr, _, err := h.allocator.Allocate(hash) + if err != nil { + return 0, err + } + + return ptr, nil +} diff --git a/internal/runtime/crypto/host_bls.go b/internal/runtime/crypto/host_bls.go new file mode 100644 index 00000000..7280ccbd --- /dev/null +++ b/internal/runtime/crypto/host_bls.go @@ -0,0 +1,210 @@ +package crypto + +import ( + "encoding/binary" + "fmt" +) + +// BLS12381AggregateG1Host implements the BLS12-381 G1 aggregation host function +func (h *HostFunctions) BLS12381AggregateG1Host(elementsPtr uint32) (uint32, error) { + // Read length prefix (4 bytes) + lenBytes, err := h.mem.Manager().ReadBytes(elementsPtr, 4) + if err != nil { + return 0, fmt.Errorf("failed to read elements length: %w", err) + } + numElements := binary.LittleEndian.Uint32(lenBytes) + + // Read elements + elements := make([][]byte, numElements) + offset := elementsPtr + 4 + for i := uint32(0); i < numElements; i++ { + // Read element length + elemLenBytes, err := h.mem.Manager().ReadBytes(offset, 4) + if err != nil { + return 0, fmt.Errorf("failed to read element length: %w", err) + } + elemLen := binary.LittleEndian.Uint32(elemLenBytes) + offset += 4 + + // Read element data + element, err := h.mem.Manager().ReadBytes(offset, elemLen) + if err != nil { + return 0, fmt.Errorf("failed to read element data: %w", err) + } + elements[i] = element + offset += elemLen + } + + // Perform aggregation + result, err := h.verifier.BLS12381AggregateG1(elements) + if err != nil { + return 0, fmt.Errorf("failed to aggregate G1 points: %w", err) + } + + // Write result to memory + ptr, _, err := h.allocator.Allocate(result) + if err != nil { + return 0, fmt.Errorf("failed to allocate memory for result: %w", err) + } + + return ptr, nil +} + +// BLS12381AggregateG2Host implements the BLS12-381 G2 aggregation host function +func (h *HostFunctions) BLS12381AggregateG2Host(elementsPtr uint32) (uint32, error) { + // Read length prefix (4 bytes) + lenBytes, err := h.mem.Manager().ReadBytes(elementsPtr, 4) + if err != nil { + return 0, fmt.Errorf("failed to read elements length: %w", err) + } + numElements := binary.LittleEndian.Uint32(lenBytes) + + // Read elements + elements := make([][]byte, numElements) + offset := elementsPtr + 4 + for i := uint32(0); i < numElements; i++ { + // Read element length + elemLenBytes, err := h.mem.Manager().ReadBytes(offset, 4) + if err != nil { + return 0, fmt.Errorf("failed to read element length: %w", err) + } + elemLen := binary.LittleEndian.Uint32(elemLenBytes) + offset += 4 + + // Read element data + element, err := h.mem.Manager().ReadBytes(offset, elemLen) + if err != nil { + return 0, fmt.Errorf("failed to read element data: %w", err) + } + elements[i] = element + offset += elemLen + } + + // Perform aggregation + result, err := h.verifier.BLS12381AggregateG2(elements) + if err != nil { + return 0, fmt.Errorf("failed to aggregate G2 points: %w", err) + } + + // Write result to memory + ptr, _, err := h.allocator.Allocate(result) + if err != nil { + return 0, fmt.Errorf("failed to allocate memory for result: %w", err) + } + + return ptr, nil +} + +// BLS12381HashToG1Host implements the BLS12-381 hash-to-G1 host function +func (h *HostFunctions) BLS12381HashToG1Host(msgPtr uint32) (uint32, error) { + // Read message region + msgRegion, err := h.mem.ReadRegion(msgPtr) + if err != nil { + return 0, err + } + + // Read message from memory + msg, err := h.mem.ReadString(msgRegion) + if err != nil { + return 0, err + } + + // Hash to G1 + result, err := h.verifier.BLS12381HashToG1([]byte(msg)) + if err != nil { + return 0, fmt.Errorf("failed to hash to G1: %w", err) + } + + // Write result to memory + ptr, _, err := h.allocator.Allocate(result) + if err != nil { + return 0, fmt.Errorf("failed to allocate memory for result: %w", err) + } + + return ptr, nil +} + +// BLS12381HashToG2Host implements the BLS12-381 hash-to-G2 host function +func (h *HostFunctions) BLS12381HashToG2Host(msgPtr uint32) (uint32, error) { + // Read message region + msgRegion, err := h.mem.ReadRegion(msgPtr) + if err != nil { + return 0, err + } + + // Read message from memory + msg, err := h.mem.ReadString(msgRegion) + if err != nil { + return 0, err + } + + // Hash to G2 + result, err := h.verifier.BLS12381HashToG2([]byte(msg)) + if err != nil { + return 0, fmt.Errorf("failed to hash to G2: %w", err) + } + + // Write result to memory + ptr, _, err := h.allocator.Allocate(result) + if err != nil { + return 0, fmt.Errorf("failed to allocate memory for result: %w", err) + } + + return ptr, nil +} + +// BLS12381PairingEqualityHost implements the BLS12-381 pairing equality check host function +func (h *HostFunctions) BLS12381PairingEqualityHost(a1Ptr, a2Ptr, b1Ptr, b2Ptr uint32) (uint32, error) { + // Read all regions + a1Region, err := h.mem.ReadRegion(a1Ptr) + if err != nil { + return 0, err + } + + a2Region, err := h.mem.ReadRegion(a2Ptr) + if err != nil { + return 0, err + } + + b1Region, err := h.mem.ReadRegion(b1Ptr) + if err != nil { + return 0, err + } + + b2Region, err := h.mem.ReadRegion(b2Ptr) + if err != nil { + return 0, err + } + + // Read all points from memory + a1, err := h.mem.ReadString(a1Region) + if err != nil { + return 0, err + } + + a2, err := h.mem.ReadString(a2Region) + if err != nil { + return 0, err + } + + b1, err := h.mem.ReadString(b1Region) + if err != nil { + return 0, err + } + + b2, err := h.mem.ReadString(b2Region) + if err != nil { + return 0, err + } + + // Check pairing equality + result, err := h.verifier.BLS12381PairingEquality([]byte(a1), []byte(a2), []byte(b1), []byte(b2)) + if err != nil { + return 0, fmt.Errorf("failed to check pairing equality: %w", err) + } + + if result { + return 1, nil + } + return 0, nil +} diff --git a/internal/runtime/crypto/host_secp256r1.go b/internal/runtime/crypto/host_secp256r1.go new file mode 100644 index 00000000..55104fa0 --- /dev/null +++ b/internal/runtime/crypto/host_secp256r1.go @@ -0,0 +1,90 @@ +package crypto + +import ( + "fmt" +) + +// Secp256r1VerifyHost implements the secp256r1 signature verification host function +func (h *HostFunctions) Secp256r1VerifyHost(hashPtr, sigPtr, pubkeyPtr uint32) (uint32, error) { + // Read regions + hashRegion, err := h.mem.ReadRegion(hashPtr) + if err != nil { + return 0, err + } + + sigRegion, err := h.mem.ReadRegion(sigPtr) + if err != nil { + return 0, err + } + + pubkeyRegion, err := h.mem.ReadRegion(pubkeyPtr) + if err != nil { + return 0, err + } + + // Read data from memory + hash, err := h.mem.ReadString(hashRegion) + if err != nil { + return 0, err + } + + sig, err := h.mem.ReadString(sigRegion) + if err != nil { + return 0, err + } + + pubkey, err := h.mem.ReadString(pubkeyRegion) + if err != nil { + return 0, err + } + + // Verify signature + ok, err := h.verifier.Secp256r1Verify([]byte(hash), []byte(sig), []byte(pubkey)) + if err != nil { + return 0, fmt.Errorf("failed to verify secp256r1 signature: %w", err) + } + + if ok { + return 1, nil + } + return 0, nil +} + +// Secp256r1RecoverPubkeyHost implements the secp256r1 public key recovery host function +func (h *HostFunctions) Secp256r1RecoverPubkeyHost(hashPtr uint32, sigPtr uint32, recoveryParam uint32) (uint32, error) { + // Read regions + hashRegion, err := h.mem.ReadRegion(hashPtr) + if err != nil { + return 0, err + } + + sigRegion, err := h.mem.ReadRegion(sigPtr) + if err != nil { + return 0, err + } + + // Read data from memory + hash, err := h.mem.ReadString(hashRegion) + if err != nil { + return 0, err + } + + sig, err := h.mem.ReadString(sigRegion) + if err != nil { + return 0, err + } + + // Try to recover public key + pubkey, err := h.verifier.Secp256r1RecoverPubkey([]byte(hash), []byte(sig), byte(recoveryParam)) + if err != nil { + return 0, fmt.Errorf("failed to recover secp256r1 public key: %w", err) + } + + // Write result to memory + ptr, _, err := h.allocator.Allocate(pubkey) + if err != nil { + return 0, fmt.Errorf("failed to allocate memory for result: %w", err) + } + + return ptr, nil +} diff --git a/internal/runtime/crypto/secp256r1.go b/internal/runtime/crypto/secp256r1.go new file mode 100644 index 00000000..a653989c --- /dev/null +++ b/internal/runtime/crypto/secp256r1.go @@ -0,0 +1,52 @@ +package crypto + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "errors" + "fmt" + "math/big" +) + +// Secp256r1Verify verifies a P-256 ECDSA signature. +// hash is the message digest (NOT the preimage), +// signature should be 64 bytes (r and s concatenated), +// pubkey should be an uncompressed or compressed public key in standard format. +func (v *Verifier) Secp256r1Verify(hash, signature, pubkey []byte) (bool, error) { + // Parse public key using crypto/ecdh + curve := ecdh.P256() + key, err := curve.NewPublicKey(pubkey) + if err != nil { + return false, fmt.Errorf("invalid public key: %w", err) + } + + // Get the raw coordinates for ECDSA verification + rawKey := key.Bytes() + x, y := elliptic.UnmarshalCompressed(elliptic.P256(), rawKey) + if x == nil { + return false, errors.New("failed to parse public key coordinates") + } + + // Parse signature: must be exactly 64 bytes => r (first 32 bytes), s (second 32 bytes). + if len(signature) != 64 { + return false, fmt.Errorf("signature must be 64 bytes, got %d", len(signature)) + } + r := new(big.Int).SetBytes(signature[:32]) + s := new(big.Int).SetBytes(signature[32:64]) + + pub := &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: x, + Y: y, + } + + verified := ecdsa.Verify(pub, hash, r, s) + return verified, nil +} + +// Secp256r1RecoverPubkey tries to recover a P-256 public key from a signature. +// This is a placeholder as P-256 doesn't support standard pubkey recovery. +func (v *Verifier) Secp256r1RecoverPubkey(hash, signature []byte, recovery byte) ([]byte, error) { + return nil, errors.New("secp256r1 public key recovery not supported") +} diff --git a/internal/runtime/db/host.go b/internal/runtime/db/host.go new file mode 100644 index 00000000..e6b7f5e2 --- /dev/null +++ b/internal/runtime/db/host.go @@ -0,0 +1,100 @@ +package db + +import ( + "github.com/CosmWasm/wasmvm/v2/internal/runtime/memory" + "github.com/CosmWasm/wasmvm/v2/types" +) + +// HostFunctions provides database-related host functions +type HostFunctions struct { + store *Store + mem *memory.HostFunctions + allocator *memory.Allocator +} + +// NewHostFunctions creates a new set of database host functions +func NewHostFunctions(store types.KVStore, mem *memory.HostFunctions, allocator *memory.Allocator) *HostFunctions { + return &HostFunctions{ + store: New(store), + mem: mem, + allocator: allocator, + } +} + +// Get reads a value from the store +func (h *HostFunctions) Get(keyPtr uint32) (uint32, error) { + // Read key region + keyRegion, err := h.mem.ReadRegion(keyPtr) + if err != nil { + return 0, err + } + + // Read key from memory + key, err := h.mem.ReadString(keyRegion) + if err != nil { + return 0, err + } + + // Get value from store + value := h.store.Get([]byte(key)) + if value == nil { + return 0, nil + } + + // Write value to memory and return pointer + ptr, _, err := h.allocator.Allocate(value) + if err != nil { + return 0, err + } + + return ptr, nil +} + +// Set writes a value to the store +func (h *HostFunctions) Set(keyPtr, valuePtr uint32) error { + // Read key region + keyRegion, err := h.mem.ReadRegion(keyPtr) + if err != nil { + return err + } + + // Read value region + valueRegion, err := h.mem.ReadRegion(valuePtr) + if err != nil { + return err + } + + // Read key and value from memory + key, err := h.mem.ReadString(keyRegion) + if err != nil { + return err + } + + value, err := h.mem.ReadString(valueRegion) + if err != nil { + return err + } + + // Set value in store + h.store.Set([]byte(key), []byte(value)) + return nil +} + +// Delete removes a key from the store +func (h *HostFunctions) Delete(keyPtr uint32) error { + // Read key region + keyRegion, err := h.mem.ReadRegion(keyPtr) + if err != nil { + return err + } + + // Read key from memory + key, err := h.mem.ReadString(keyRegion) + if err != nil { + return err + } + + // Delete from store + h.store.Delete([]byte(key)) + return nil +} diff --git a/internal/runtime/db/store.go b/internal/runtime/db/store.go new file mode 100644 index 00000000..032108fc --- /dev/null +++ b/internal/runtime/db/store.go @@ -0,0 +1,88 @@ +package db + +import ( + "bytes" + "sync" + + "github.com/CosmWasm/wasmvm/v2/types" +) + +// Store implements a thread-safe key-value store +type Store struct { + mu sync.RWMutex + store types.KVStore +} + +// New creates a new store instance +func New(store types.KVStore) *Store { + return &Store{ + store: store, + } +} + +// Get retrieves a value by key +func (s *Store) Get(key []byte) []byte { + s.mu.RLock() + defer s.mu.RUnlock() + return s.store.Get(key) +} + +// Set stores a key-value pair +func (s *Store) Set(key, value []byte) { + s.mu.Lock() + defer s.mu.Unlock() + s.store.Set(key, value) +} + +// Delete removes a key-value pair +func (s *Store) Delete(key []byte) { + s.mu.Lock() + defer s.mu.Unlock() + s.store.Delete(key) +} + +// Iterator creates an iterator over a domain of keys +func (s *Store) Iterator(start, end []byte) types.Iterator { + s.mu.RLock() + defer s.mu.RUnlock() + return s.store.Iterator(start, end) +} + +// ReverseIterator creates a reverse iterator over a domain of keys +func (s *Store) ReverseIterator(start, end []byte) types.Iterator { + s.mu.RLock() + defer s.mu.RUnlock() + return s.store.ReverseIterator(start, end) +} + +// PrefixIterator creates an iterator over a domain of keys with a prefix +func (s *Store) PrefixIterator(prefix []byte) types.Iterator { + end := calculatePrefixEnd(prefix) + return s.Iterator(prefix, end) +} + +// ReversePrefixIterator creates a reverse iterator over a domain of keys with a prefix +func (s *Store) ReversePrefixIterator(prefix []byte) types.Iterator { + end := calculatePrefixEnd(prefix) + return s.ReverseIterator(prefix, end) +} + +// calculatePrefixEnd returns the end key for prefix iteration +func calculatePrefixEnd(prefix []byte) []byte { + if len(prefix) == 0 { + return nil + } + + end := make([]byte, len(prefix)) + copy(end, prefix) + + for i := len(end) - 1; i >= 0; i-- { + end[i]++ + if end[i] != 0 { + return end[:i+1] + } + } + + // If we got here, we had a prefix of all 0xff values + return bytes.Repeat([]byte{0xff}, len(prefix)) +} diff --git a/internal/runtime/error/errors.go b/internal/runtime/error/errors.go new file mode 100644 index 00000000..a2f5c3bd --- /dev/null +++ b/internal/runtime/error/errors.go @@ -0,0 +1,33 @@ +package error + +import ( + "fmt" +) + +// RuntimeError represents a generic runtime error +type RuntimeError struct { + Msg string + Err error +} + +func (e *RuntimeError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %v", e.Msg, e.Err) + } + return e.Msg +} + +// GasError represents an error related to gas consumption +type GasError struct { + Wanted uint64 + Available uint64 +} + +func (e *GasError) Error() string { + return fmt.Sprintf("insufficient gas: required %d, but only %d available", e.Wanted, e.Available) +} + +// ToWasmVMError converts internal errors to wasmvm types.SystemError +func ToWasmVMError(err error) error { + return err +} diff --git a/internal/runtime/gas/meter.go b/internal/runtime/gas/meter.go new file mode 100644 index 00000000..36a2e9ce --- /dev/null +++ b/internal/runtime/gas/meter.go @@ -0,0 +1,58 @@ +package gas + +import ( + rterrors "github.com/CosmWasm/wasmvm/v2/internal/runtime/error" +) + +// Meter tracks gas consumption during contract execution +type Meter interface { + // Consume charges the specified amount of gas + Consume(amount uint64) error + // Remaining returns the amount of gas left + Remaining() uint64 + // HasGas checks if there is any gas left + HasGas() bool +} + +// DefaultMeter is the default implementation of Meter +type DefaultMeter struct { + limit uint64 + consumed uint64 +} + +// NewDefaultMeter creates a new gas meter with the specified limit +func NewDefaultMeter(limit uint64) *DefaultMeter { + return &DefaultMeter{ + limit: limit, + consumed: 0, + } +} + +func (m *DefaultMeter) Consume(amount uint64) error { + if m.consumed+amount > m.limit { + return &rterrors.GasError{ + Wanted: amount, + Available: m.Remaining(), + } + } + m.consumed += amount + return nil +} + +func (m *DefaultMeter) Remaining() uint64 { + if m.consumed >= m.limit { + return 0 + } + return m.limit - m.consumed +} + +func (m *DefaultMeter) HasGas() bool { + return m.Remaining() > 0 +} + +// Report contains information about gas usage +type Report struct { + Limit uint64 + Remaining uint64 + Used uint64 +} diff --git a/internal/runtime/hostcrypto.go b/internal/runtime/hostcrypto.go deleted file mode 100644 index 656560e3..00000000 --- a/internal/runtime/hostcrypto.go +++ /dev/null @@ -1,271 +0,0 @@ -package runtime - -import ( - "context" - "encoding/binary" - "fmt" - - "github.com/tetratelabs/wazero/api" -) - -// hostBls12381AggregateG1 implements bls12_381_aggregate_g1 -func hostBls12381AggregateG1(ctx context.Context, mod api.Module, elementsPtr uint32) (uint32, uint32) { - mem := mod.Memory() - - // Read length prefix (4 bytes) - lenBytes, err := readMemory(mem, elementsPtr, 4) - if err != nil { - panic(fmt.Sprintf("failed to read elements length: %v", err)) - } - numElements := binary.LittleEndian.Uint32(lenBytes) - - // Read elements - elements := make([][]byte, numElements) - offset := elementsPtr + 4 - for i := uint32(0); i < numElements; i++ { - // Read element length - elemLenBytes, err := readMemory(mem, offset, 4) - if err != nil { - panic(fmt.Sprintf("failed to read element length: %v", err)) - } - elemLen := binary.LittleEndian.Uint32(elemLenBytes) - offset += 4 - - // Read element data - element, err := readMemory(mem, offset, elemLen) - if err != nil { - panic(fmt.Sprintf("failed to read element data: %v", err)) - } - elements[i] = element - offset += elemLen - } - - // Perform aggregation - result, err := BLS12381AggregateG1(elements) - if err != nil { - panic(fmt.Sprintf("failed to aggregate G1 points: %v", err)) - } - - // Allocate memory for result - resultPtr, err := allocateInContract(ctx, mod, uint32(len(result))) - if err != nil { - panic(fmt.Sprintf("failed to allocate memory for result: %v", err)) - } - - // Write result - if err := writeMemory(mem, resultPtr, result); err != nil { - panic(fmt.Sprintf("failed to write result: %v", err)) - } - - return resultPtr, uint32(len(result)) -} - -// hostBls12381AggregateG2 implements bls12_381_aggregate_g2 -func hostBls12381AggregateG2(ctx context.Context, mod api.Module, elementsPtr uint32) (uint32, uint32) { - mem := mod.Memory() - - // Read length prefix (4 bytes) - lenBytes, err := readMemory(mem, elementsPtr, 4) - if err != nil { - panic(fmt.Sprintf("failed to read elements length: %v", err)) - } - numElements := binary.LittleEndian.Uint32(lenBytes) - - // Read elements - elements := make([][]byte, numElements) - offset := elementsPtr + 4 - for i := uint32(0); i < numElements; i++ { - // Read element length - elemLenBytes, err := readMemory(mem, offset, 4) - if err != nil { - panic(fmt.Sprintf("failed to read element length: %v", err)) - } - elemLen := binary.LittleEndian.Uint32(elemLenBytes) - offset += 4 - - // Read element data - element, err := readMemory(mem, offset, elemLen) - if err != nil { - panic(fmt.Sprintf("failed to read element data: %v", err)) - } - elements[i] = element - offset += elemLen - } - - // Perform aggregation - result, err := BLS12381AggregateG2(elements) - if err != nil { - panic(fmt.Sprintf("failed to aggregate G2 points: %v", err)) - } - - // Allocate memory for result - resultPtr, err := allocateInContract(ctx, mod, uint32(len(result))) - if err != nil { - panic(fmt.Sprintf("failed to allocate memory for result: %v", err)) - } - - // Write result - if err := writeMemory(mem, resultPtr, result); err != nil { - panic(fmt.Sprintf("failed to write result: %v", err)) - } - - return resultPtr, uint32(len(result)) -} - -// hostBls12381HashToG1 implements bls12_381_hash_to_g1 -func hostBls12381HashToG1(ctx context.Context, mod api.Module, hashPtr, hashLen uint32) (uint32, uint32) { - mem := mod.Memory() - - // Read hash - hash, err := readMemory(mem, hashPtr, hashLen) - if err != nil { - panic(fmt.Sprintf("failed to read hash: %v", err)) - } - - // Perform hash-to-curve - result, err := BLS12381HashToG1(hash) - if err != nil { - panic(fmt.Sprintf("failed to hash to G1: %v", err)) - } - - // Allocate memory for result - resultPtr, err := allocateInContract(ctx, mod, uint32(len(result))) - if err != nil { - panic(fmt.Sprintf("failed to allocate memory for result: %v", err)) - } - - // Write result - if err := writeMemory(mem, resultPtr, result); err != nil { - panic(fmt.Sprintf("failed to write result: %v", err)) - } - - return resultPtr, uint32(len(result)) -} - -// hostBls12381HashToG2 implements bls12_381_hash_to_g2 -func hostBls12381HashToG2(ctx context.Context, mod api.Module, hashPtr, hashLen uint32) (uint32, uint32) { - mem := mod.Memory() - - // Read hash - hash, err := readMemory(mem, hashPtr, hashLen) - if err != nil { - panic(fmt.Sprintf("failed to read hash: %v", err)) - } - - // Perform hash-to-curve - result, err := BLS12381HashToG2(hash) - if err != nil { - panic(fmt.Sprintf("failed to hash to G2: %v", err)) - } - - // Allocate memory for result - resultPtr, err := allocateInContract(ctx, mod, uint32(len(result))) - if err != nil { - panic(fmt.Sprintf("failed to allocate memory for result: %v", err)) - } - - // Write result - if err := writeMemory(mem, resultPtr, result); err != nil { - panic(fmt.Sprintf("failed to write result: %v", err)) - } - - return resultPtr, uint32(len(result)) -} - -// hostBls12381PairingEquality implements bls12_381_pairing_equality -func hostBls12381PairingEquality(ctx context.Context, mod api.Module, a1Ptr, a1Len, a2Ptr, a2Len, b1Ptr, b1Len, b2Ptr, b2Len uint32) uint32 { - mem := mod.Memory() - - // Read points - a1, err := readMemory(mem, a1Ptr, a1Len) - if err != nil { - panic(fmt.Sprintf("failed to read a1: %v", err)) - } - a2, err := readMemory(mem, a2Ptr, a2Len) - if err != nil { - panic(fmt.Sprintf("failed to read a2: %v", err)) - } - b1, err := readMemory(mem, b1Ptr, b1Len) - if err != nil { - panic(fmt.Sprintf("failed to read b1: %v", err)) - } - b2, err := readMemory(mem, b2Ptr, b2Len) - if err != nil { - panic(fmt.Sprintf("failed to read b2: %v", err)) - } - - // Check pairing equality - result, err := BLS12381PairingEquality(a1, a2, b1, b2) - if err != nil { - panic(fmt.Sprintf("failed to check pairing equality: %v", err)) - } - - if result { - return 1 - } - return 0 -} - -// hostSecp256r1Verify implements secp256r1_verify -func hostSecp256r1Verify(ctx context.Context, mod api.Module, hashPtr, hashLen, sigPtr, sigLen, pubkeyPtr, pubkeyLen uint32) uint32 { - mem := mod.Memory() - - // Read inputs - hash, err := readMemory(mem, hashPtr, hashLen) - if err != nil { - panic(fmt.Sprintf("failed to read hash: %v", err)) - } - signature, err := readMemory(mem, sigPtr, sigLen) - if err != nil { - panic(fmt.Sprintf("failed to read signature: %v", err)) - } - pubkey, err := readMemory(mem, pubkeyPtr, pubkeyLen) - if err != nil { - panic(fmt.Sprintf("failed to read public key: %v", err)) - } - - // Verify signature - result, err := Secp256r1Verify(hash, signature, pubkey) - if err != nil { - panic(fmt.Sprintf("failed to verify signature: %v", err)) - } - - if result { - return 1 - } - return 0 -} - -// hostSecp256r1RecoverPubkey implements secp256r1_recover_pubkey -func hostSecp256r1RecoverPubkey(ctx context.Context, mod api.Module, hashPtr, hashLen, sigPtr, sigLen, recovery uint32) (uint32, uint32) { - mem := mod.Memory() - - // Read inputs - hash, err := readMemory(mem, hashPtr, hashLen) - if err != nil { - panic(fmt.Sprintf("failed to read hash: %v", err)) - } - signature, err := readMemory(mem, sigPtr, sigLen) - if err != nil { - panic(fmt.Sprintf("failed to read signature: %v", err)) - } - - // Recover public key - pubkey, err := Secp256r1RecoverPubkey(hash, signature, byte(recovery)) - if err != nil { - panic(fmt.Sprintf("failed to recover public key: %v", err)) - } - - // Allocate memory for result - resultPtr, err := allocateInContract(ctx, mod, uint32(len(pubkey))) - if err != nil { - panic(fmt.Sprintf("failed to allocate memory for result: %v", err)) - } - - // Write result - if err := writeMemory(mem, resultPtr, pubkey); err != nil { - panic(fmt.Sprintf("failed to write result: %v", err)) - } - - return resultPtr, uint32(len(pubkey)) -} diff --git a/internal/runtime/hostfunctions.go b/internal/runtime/hostfunctions.go index e7639131..f4e5cc22 100644 --- a/internal/runtime/hostfunctions.go +++ b/internal/runtime/hostfunctions.go @@ -996,7 +996,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, psPtr, qsPtr, rPtr, sPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostBls12381PairingEquality(ctx, m, psPtr, 0, qsPtr, 0, rPtr, 0, sPtr, 0) + return crypto.hostBls12381PairingEquality(ctx, m, psPtr, 0, qsPtr, 0, rPtr, 0, sPtr, 0) }). WithParameterNames("ps_ptr", "qs_ptr", "r_ptr", "s_ptr"). WithResultNames("result"). @@ -1005,7 +1005,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - ptr, _ := hostBls12381HashToG1(ctx, m, msgPtr, hashFunction) + ptr, _ := crypto.hostBls12381HashToG1(ctx, m, msgPtr, hashFunction) return ptr }). WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr"). @@ -1015,7 +1015,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - ptr, _ := hostBls12381HashToG2(ctx, m, msgPtr, hashFunction) + ptr, _ := crypto.hostBls12381HashToG2(ctx, m, msgPtr, hashFunction) return ptr }). WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr"). @@ -1026,7 +1026,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, publicKeyPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostSecp256r1Verify(ctx, m, messageHashPtr, 0, signaturePtr, 0, publicKeyPtr, 0) + return crypto.hostSecp256r1Verify(ctx, m, messageHashPtr, 0, signaturePtr, 0, publicKeyPtr, 0) }). WithParameterNames("message_hash_ptr", "signature_ptr", "public_key_ptr"). WithResultNames("result"). @@ -1035,7 +1035,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, recoveryParam uint32) uint64 { ctx = context.WithValue(ctx, envKey, env) - ptr, len := hostSecp256r1RecoverPubkey(ctx, m, messageHashPtr, 0, signaturePtr, 0, recoveryParam) + ptr, len := crypto.hostSecp256r1RecoverPubkey(ctx, m, messageHashPtr, 0, signaturePtr, 0, recoveryParam) return (uint64(len) << 32) | uint64(ptr) }). WithParameterNames("message_hash_ptr", "signature_ptr", "recovery_param"). @@ -1232,26 +1232,6 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz return builder.Compile(context.Background()) } -// When you instantiate a contract, you can do something like: -// -// compiledHost, err := RegisterHostFunctions(runtime, env) -// if err != nil { -// ... -// } -// _, err = runtime.InstantiateModule(ctx, compiledHost, wazero.NewModuleConfig()) -// if err != nil { -// ... -// } -// -// Then, instantiate your contract module which imports "env" module's functions. - -// contextKey is a custom type for context keys to avoid collisions -type contextKey string - -const ( - envKey contextKey = "env" -) - // hostNextKey implements db_next_key func hostNextKey(ctx context.Context, mod api.Module, callID, iterID uint64) (keyPtr, keyLen, errCode uint32) { env := ctx.Value("env").(*RuntimeEnvironment) diff --git a/internal/runtime/hostfunctions/register.go b/internal/runtime/hostfunctions/register.go new file mode 100644 index 00000000..adb7bad4 --- /dev/null +++ b/internal/runtime/hostfunctions/register.go @@ -0,0 +1,49 @@ +package hostfunctions + +import ( + "context" + "fmt" + + "github.com/CosmWasm/wasmvm/v2/internal/runtime" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + envKey contextKey = "env" +) + +// RegisterHostFunctions registers all host functions with the wazero runtime +func RegisterHostFunctions(runtime wazero.Runtime, env *runtime.RuntimeEnvironment) (wazero.CompiledModule, error) { + builder := runtime.NewHostModuleBuilder("env") + + // Register abort function + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, code uint32) { + ctx = context.WithValue(ctx, envKey, env) + panic(fmt.Sprintf("Wasm contract aborted with code: %d (0x%x)", code, code)) + }). + WithParameterNames("code"). + Export("abort") + + // Register BLS12-381 functions + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, g1sPtr, outPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.Crypto.BLS12381AggregateG1Host(g1sPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("g1s_ptr", "out_ptr"). + WithResultNames("result"). + Export("bls12_381_aggregate_g1") + + // Add other host functions here... + + return builder.Compile(context.Background()) +} diff --git a/internal/runtime/iterator/iterator.go b/internal/runtime/iterator/iterator.go new file mode 100644 index 00000000..79b2c33c --- /dev/null +++ b/internal/runtime/iterator/iterator.go @@ -0,0 +1,64 @@ +package iterator + +import ( + "sync" + + "github.com/CosmWasm/wasmvm/v2/types" +) + +// Manager handles database iterators +type Manager struct { + mu sync.RWMutex + iterators map[uint64]types.Iterator + nextID uint64 +} + +// New creates a new iterator manager +func New() *Manager { + return &Manager{ + iterators: make(map[uint64]types.Iterator), + nextID: 1, + } +} + +// Create stores an iterator and returns its ID +func (m *Manager) Create(iter types.Iterator) uint64 { + m.mu.Lock() + defer m.mu.Unlock() + + id := m.nextID + m.iterators[id] = iter + m.nextID++ + return id +} + +// Get retrieves an iterator by its ID +func (m *Manager) Get(id uint64) (types.Iterator, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + iter, exists := m.iterators[id] + return iter, exists +} + +// Remove deletes an iterator +func (m *Manager) Remove(id uint64) { + m.mu.Lock() + defer m.mu.Unlock() + + if iter, exists := m.iterators[id]; exists { + iter.Close() + delete(m.iterators, id) + } +} + +// RemoveAll deletes all iterators +func (m *Manager) RemoveAll() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, iter := range m.iterators { + iter.Close() + } + m.iterators = make(map[uint64]types.Iterator) +} diff --git a/internal/runtime/memory/allocator.go b/internal/runtime/memory/allocator.go new file mode 100644 index 00000000..b190b630 --- /dev/null +++ b/internal/runtime/memory/allocator.go @@ -0,0 +1,122 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tetratelabs/wazero/api" +) + +// Allocator handles memory allocation in Wasm modules +type Allocator struct { + memory api.Memory + module api.Module + manager *Manager + host *HostFunctions +} + +// NewAllocator creates a new memory allocator +func NewAllocator(memory api.Memory, module api.Module) *Allocator { + manager := New(memory) + return &Allocator{ + memory: memory, + module: module, + manager: manager, + host: NewHostFunctions(memory), + } +} + +// Allocate allocates memory and writes data to it +func (a *Allocator) Allocate(data []byte) (uint32, uint32, error) { + if data == nil { + return 0, 0, nil + } + + // Get the allocate function + allocate := a.module.ExportedFunction("allocate") + if allocate == nil { + return 0, 0, fmt.Errorf("allocate function not found in WASM module") + } + + // Allocate memory for the Region struct (12 bytes) and the data + size := uint32(len(data)) + results, err := allocate.Call(context.Background(), uint64(size+regionSize)) + if err != nil { + return 0, 0, fmt.Errorf("failed to allocate memory: %w", err) + } + ptr := uint32(results[0]) + + // Create and write the Region struct + region := &Region{ + Offset: ptr + regionSize, // Data starts after the Region struct + Capacity: size, + Length: size, + } + + // Validate the region before writing + if err := validateRegion(region); err != nil { + if err := a.Deallocate(ptr); err != nil { + // Log deallocation error but return the original error + fmt.Printf("failed to deallocate memory after validation error: %v\n", err) + } + return 0, 0, fmt.Errorf("invalid region: %w", err) + } + + // Write the Region struct + if err := a.host.WriteRegion(ptr, region); err != nil { + if err := a.Deallocate(ptr); err != nil { + fmt.Printf("failed to deallocate memory after write error: %v\n", err) + } + return 0, 0, fmt.Errorf("failed to write region: %w", err) + } + + // Write the actual data + if err := a.manager.WriteBytes(region.Offset, data); err != nil { + if err := a.Deallocate(ptr); err != nil { + fmt.Printf("failed to deallocate memory after data write error: %v\n", err) + } + return 0, 0, fmt.Errorf("failed to write data to memory: %w", err) + } + + return ptr, size, nil +} + +// Deallocate frees allocated memory +func (a *Allocator) Deallocate(ptr uint32) error { + deallocate := a.module.ExportedFunction("deallocate") + if deallocate == nil { + return fmt.Errorf("deallocate function not found in WASM module") + } + + _, err := deallocate.Call(context.Background(), uint64(ptr)) + if err != nil { + return fmt.Errorf("failed to deallocate memory: %w", err) + } + + return nil +} + +// Read reads data from allocated memory +func (a *Allocator) Read(ptr uint32) ([]byte, error) { + if ptr == 0 { + return nil, nil + } + + // Read the Region struct + region, err := a.host.ReadRegion(ptr) + if err != nil { + return nil, fmt.Errorf("failed to read region: %w", err) + } + + // Read the actual data + data, err := a.manager.ReadBytes(region.Offset, region.Length) + if err != nil { + return nil, fmt.Errorf("failed to read memory: %w", err) + } + + // Make a copy to ensure we own the data + result := make([]byte, len(data)) + copy(result, data) + + return result, nil +} diff --git a/internal/runtime/memory/errors.go b/internal/runtime/memory/errors.go new file mode 100644 index 00000000..a7ab8227 --- /dev/null +++ b/internal/runtime/memory/errors.go @@ -0,0 +1,12 @@ +package memory + +import "errors" + +var ( + // ErrInvalidMemoryAccess is returned when trying to access invalid memory regions + ErrInvalidMemoryAccess = errors.New("invalid memory access") + // ErrMemoryReadFailed is returned when memory read operation fails + ErrMemoryReadFailed = errors.New("memory read failed") + // ErrMemoryWriteFailed is returned when memory write operation fails + ErrMemoryWriteFailed = errors.New("memory write failed") +) diff --git a/internal/runtime/memory/host.go b/internal/runtime/memory/host.go new file mode 100644 index 00000000..fff79480 --- /dev/null +++ b/internal/runtime/memory/host.go @@ -0,0 +1,93 @@ +package memory + +import ( + "fmt" + "math" + + "github.com/tetratelabs/wazero/api" +) + +// Constants for memory management +const ( + // Memory page size in WebAssembly (64KB) + wasmPageSize = 65536 + + // Size of a Region struct in bytes (3x4 bytes) + regionSize = 12 +) + +// Region describes data allocated in Wasm's linear memory +type Region struct { + Offset uint32 + Capacity uint32 + Length uint32 +} + +// validateRegion performs plausibility checks on a Region +func validateRegion(region *Region) error { + if region.Offset == 0 { + return fmt.Errorf("region has zero offset") + } + if region.Length > region.Capacity { + return fmt.Errorf("region length %d exceeds capacity %d", region.Length, region.Capacity) + } + if uint64(region.Offset)+uint64(region.Capacity) > math.MaxUint32 { + return fmt.Errorf("region out of range: offset %d, capacity %d", region.Offset, region.Capacity) + } + return nil +} + +// HostFunctions provides memory-related host functions +type HostFunctions struct { + manager *Manager +} + +// NewHostFunctions creates a new set of memory host functions +func NewHostFunctions(memory api.Memory) *HostFunctions { + return &HostFunctions{ + manager: New(memory), + } +} + +// Manager returns the underlying memory manager +func (h *HostFunctions) Manager() *Manager { + return h.manager +} + +// ReadRegion reads a Region struct from Wasm memory +func (h *HostFunctions) ReadRegion(offset uint32) (*Region, error) { + data, err := h.manager.ReadBytes(offset, regionSize) + if err != nil { + return nil, err + } + + region := &Region{ + Offset: h.manager.ReadUint32FromBytes(data[0:4]), + Capacity: h.manager.ReadUint32FromBytes(data[4:8]), + Length: h.manager.ReadUint32FromBytes(data[8:12]), + } + + if err := validateRegion(region); err != nil { + return nil, err + } + + return region, nil +} + +// ReadString reads a string from a Region in Wasm memory +func (h *HostFunctions) ReadString(region *Region) (string, error) { + data, err := h.manager.ReadBytes(region.Offset, region.Length) + if err != nil { + return "", err + } + return string(data), nil +} + +// WriteRegion writes a Region struct to Wasm memory +func (h *HostFunctions) WriteRegion(offset uint32, region *Region) error { + data := make([]byte, regionSize) + h.manager.WriteUint32ToBytes(data[0:4], region.Offset) + h.manager.WriteUint32ToBytes(data[4:8], region.Capacity) + h.manager.WriteUint32ToBytes(data[8:12], region.Length) + return h.manager.WriteBytes(offset, data) +} diff --git a/internal/runtime/memory/manager.go b/internal/runtime/memory/manager.go new file mode 100644 index 00000000..765810d3 --- /dev/null +++ b/internal/runtime/memory/manager.go @@ -0,0 +1,97 @@ +package memory + +import ( + "encoding/binary" + "sync" + + "github.com/tetratelabs/wazero/api" +) + +// Manager handles memory operations for Wasm modules +type Manager struct { + mu sync.RWMutex + memory api.Memory +} + +// New creates a new memory manager +func New(memory api.Memory) *Manager { + return &Manager{ + memory: memory, + } +} + +// ReadBytes reads a byte slice from Wasm memory +func (m *Manager) ReadBytes(offset uint32, length uint32) ([]byte, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + // Check if the memory access is within bounds + if uint64(offset)+uint64(length) > uint64(m.memory.Size()) { + return nil, ErrInvalidMemoryAccess + } + + data, ok := m.memory.Read(offset, length) + if !ok { + return nil, ErrMemoryReadFailed + } + + return data, nil +} + +// WriteBytes writes a byte slice to Wasm memory +func (m *Manager) WriteBytes(offset uint32, data []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if the memory access is within bounds + if uint64(offset)+uint64(len(data)) > uint64(m.memory.Size()) { + return ErrInvalidMemoryAccess + } + + ok := m.memory.Write(offset, data) + if !ok { + return ErrMemoryWriteFailed + } + + return nil +} + +// ReadUint32 reads a uint32 from Wasm memory +func (m *Manager) ReadUint32(offset uint32) (uint32, error) { + data, err := m.ReadBytes(offset, 4) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint32(data), nil +} + +// WriteUint32 writes a uint32 to Wasm memory +func (m *Manager) WriteUint32(offset uint32, value uint32) error { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, value) + return m.WriteBytes(offset, buf) +} + +// ReadString reads a string from Wasm memory +func (m *Manager) ReadString(offset uint32, length uint32) (string, error) { + data, err := m.ReadBytes(offset, length) + if err != nil { + return "", err + } + return string(data), nil +} + +// WriteString writes a string to Wasm memory +func (m *Manager) WriteString(offset uint32, s string) error { + return m.WriteBytes(offset, []byte(s)) +} + +// ReadUint32FromBytes reads a uint32 from a byte slice +func (m *Manager) ReadUint32FromBytes(data []byte) uint32 { + return binary.LittleEndian.Uint32(data) +} + +// WriteUint32ToBytes writes a uint32 to a byte slice +func (m *Manager) WriteUint32ToBytes(data []byte, value uint32) { + binary.LittleEndian.PutUint32(data, value) +} diff --git a/internal/runtime/register.go b/internal/runtime/register.go new file mode 100644 index 00000000..039241a5 --- /dev/null +++ b/internal/runtime/register.go @@ -0,0 +1,312 @@ +package runtime + +import ( + "context" + "fmt" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + envKey contextKey = "env" +) + +// RegisterHostFunctions registers all host functions with the wazero runtime +func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (wazero.CompiledModule, error) { + builder := runtime.NewHostModuleBuilder("env") + + // Register abort function + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, code uint32) { + ctx = context.WithValue(ctx, envKey, env) + panic(fmt.Sprintf("Wasm contract aborted with code: %d (0x%x)", code, code)) + }). + WithParameterNames("code"). + Export("abort") + + // Register BLS12-381 functions + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, g1sPtr, outPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.BLS12381AggregateG1Host(g1sPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("g1s_ptr", "out_ptr"). + WithResultNames("result"). + Export("bls12_381_aggregate_g1") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, g2sPtr, outPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.BLS12381AggregateG2Host(g2sPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("g2s_ptr", "out_ptr"). + WithResultNames("result"). + Export("bls12_381_aggregate_g2") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, psPtr, qsPtr, rPtr, sPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.BLS12381PairingEqualityHost(psPtr, qsPtr, rPtr, sPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("ps_ptr", "qs_ptr", "r_ptr", "s_ptr"). + WithResultNames("result"). + Export("bls12_381_pairing_equality") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.BLS12381HashToG1Host(msgPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr"). + WithResultNames("result"). + Export("bls12_381_hash_to_g1") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.BLS12381HashToG2Host(msgPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr"). + WithResultNames("result"). + Export("bls12_381_hash_to_g2") + + // SECP256r1 functions + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, publicKeyPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.Secp256r1VerifyHost(messageHashPtr, signaturePtr, publicKeyPtr) + if err != nil { + panic(err) + } + return ptr + }). + WithParameterNames("message_hash_ptr", "signature_ptr", "public_key_ptr"). + WithResultNames("result"). + Export("secp256r1_verify") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, recoveryParam uint32) uint64 { + ctx = context.WithValue(ctx, envKey, env) + ptr, err := env.crypto.Secp256r1RecoverPubkeyHost(messageHashPtr, signaturePtr, recoveryParam) + if err != nil { + panic(err) + } + return uint64(ptr) + }). + WithParameterNames("message_hash_ptr", "signature_ptr", "recovery_param"). + WithResultNames("result"). + Export("secp256r1_recover_pubkey") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, startPtr, startLen, order uint32) uint32 { + // Get environment from context + env := ctx.Value(envKey).(*RuntimeEnvironment) + + // Charge gas for scan operation (gasCostIteratorCreate + 1 gas per byte scanned) + env.gasUsed += gasCostIteratorCreate + uint64(startLen) + if env.gasUsed > env.Gas.GasConsumed() { + panic("out of gas") + } + + return hostScan(ctx, m, startPtr, startLen, order) + }). + WithParameterNames("start_ptr", "start_len", "order"). + WithResultNames("iter_id"). + Export("db_scan") + + // db_next + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 { + // Get environment from context + env := ctx.Value(envKey).(*RuntimeEnvironment) + + // Charge gas for next operation + env.gasUsed += gasCostIteratorNext + if env.gasUsed > env.Gas.GasConsumed() { + panic("out of gas") + } + + return hostNext(ctx, m, iterID) + }). + WithParameterNames("iter_id"). + WithResultNames("kv_region_ptr"). + Export("db_next") + + // db_next_value + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 { + // Get environment from context + env := ctx.Value(envKey).(*RuntimeEnvironment) + + // Charge gas for next value operation + env.gasUsed += gasCostIteratorNext + if env.gasUsed > env.Gas.GasConsumed() { + panic("out of gas") + } + + // Extract call_id and iter_id from the packed uint32 + callID := uint64(iterID >> 16) + actualIterID := uint64(iterID & 0xFFFF) + ptr, _, _ := hostNextValue(ctx, m, callID, actualIterID) + return ptr + }). + WithParameterNames("iter_id"). + WithResultNames("value_ptr"). + Export("db_next_value") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostHumanizeAddress(ctx, m, addrPtr, addrLen) + }). + WithParameterNames("addr_ptr", "addr_len"). + WithResultNames("result"). + Export("addr_humanize") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, addrPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostValidateAddress(ctx, m, addrPtr) + }). + WithParameterNames("addr_ptr"). + WithResultNames("result"). + Export("addr_validate") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostCanonicalizeAddress(ctx, m, addrPtr, addrLen) + }). + WithParameterNames("addr_ptr", "addr_len"). + WithResultNames("result"). + Export("addr_canonicalize") + + // Register Query functions + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, reqPtr, reqLen, gasLimit uint32) (uint32, uint32) { + ctx = context.WithValue(ctx, envKey, env) + return hostQueryExternal(ctx, m, reqPtr, reqLen, gasLimit) + }). + WithParameterNames("req_ptr", "req_len", "gas_limit"). + Export("querier_query") + + // Register secp256k1_verify function + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, hash_ptr, sig_ptr, pubkey_ptr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostSecp256k1Verify(ctx, m, hash_ptr, sig_ptr, pubkey_ptr) + }). + WithParameterNames("hash_ptr", "sig_ptr", "pubkey_ptr"). + WithResultNames("result"). + Export("secp256k1_verify") + + // Register DB read/write/remove functions + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, keyPtr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostDbRead(ctx, m, keyPtr) + }). + WithParameterNames("key_ptr"). + Export("db_read") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, keyPtr, valuePtr uint32) { + ctx = context.WithValue(ctx, envKey, env) + hostDbWrite(ctx, m, keyPtr, valuePtr) + }). + WithParameterNames("key_ptr", "value_ptr"). + Export("db_write") + + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, keyPtr uint32) { + ctx = context.WithValue(ctx, envKey, env) + hostDbRemove(ctx, m, keyPtr) + }). + WithParameterNames("key_ptr"). + Export("db_remove") + + // db_close_iterator + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) { + ctx = context.WithValue(ctx, envKey, env) + hostCloseIterator(ctx, m, callID, iterID) + }). + WithParameterNames("call_id", "iter_id"). + Export("db_close_iterator") + + // Register secp256k1_recover_pubkey function + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, hash_ptr, sig_ptr, rec_id uint32) uint64 { + ctx = context.WithValue(ctx, envKey, env) + return hostSecp256k1RecoverPubkey(ctx, m, hash_ptr, sig_ptr, rec_id) + }). + WithParameterNames("hash_ptr", "sig_ptr", "rec_id"). + WithResultNames("result"). + Export("secp256k1_recover_pubkey") + + // Register ed25519_verify function with i32i32i32_i32 signature + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, msg_ptr, sig_ptr, pubkey_ptr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostEd25519Verify(ctx, m, msg_ptr, sig_ptr, pubkey_ptr) + }). + WithParameterNames("msg_ptr", "sig_ptr", "pubkey_ptr"). + WithResultNames("result"). + Export("ed25519_verify") + + // Register ed25519_batch_verify function with i32i32i32_i32 signature + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, msgs_ptr, sigs_ptr, pubkeys_ptr uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + return hostEd25519BatchVerify(ctx, m, msgs_ptr, sigs_ptr, pubkeys_ptr) + }). + WithParameterNames("msgs_ptr", "sigs_ptr", "pubkeys_ptr"). + WithResultNames("result"). + Export("ed25519_batch_verify") + + // Register debug function with i32_v signature + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, msgPtr uint32) { + ctx = context.WithValue(ctx, envKey, env) + hostDebug(ctx, m, msgPtr) + }). + WithParameterNames("msg_ptr"). + Export("debug") + + // db_next_key + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, _, _ := hostNextKey(ctx, m, uint64(iterID), 0) + return ptr + }). + WithParameterNames("iter_id"). + WithResultNames("key_ptr"). + Export("db_next_key") + + return builder.Compile(context.Background()) +} diff --git a/internal/runtime/types/types.go b/internal/runtime/types/types.go new file mode 100644 index 00000000..31e7f362 --- /dev/null +++ b/internal/runtime/types/types.go @@ -0,0 +1,47 @@ +package types + +// DB represents a key-value store interface +type DB interface { + Get(key []byte) []byte + Set(key, value []byte) + Delete(key []byte) + Iterator(start, end []byte) Iterator +} + +// Iterator represents a database iterator +type Iterator interface { + Valid() bool + Key() []byte + Value() []byte + Next() + Close() +} + +// API represents the contract API interface +type API interface { + ValidateAddress(addr string) (uint64, error) +} + +// GasMeter represents a gas meter interface +type GasMeter interface { + ConsumeGas(amount uint64) + GasConsumed() uint64 + GasLimit() uint64 +} + +// KVStore represents a key-value store interface +type KVStore interface { + Get(key []byte) []byte + Set(key, value []byte) + Delete(key []byte) +} + +// Querier represents a query interface +type Querier interface { + Query(request []byte) ([]byte, error) +} + +// GoAPI represents the Go API interface +type GoAPI interface { + ValidateAddress(addr string) (uint64, error) +} diff --git a/internal/runtime/wazeroruntime.go b/internal/runtime/wazeroruntime.go index fc9c6fec..05f13eb1 100644 --- a/internal/runtime/wazeroruntime.go +++ b/internal/runtime/wazeroruntime.go @@ -15,6 +15,8 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" + "github.com/CosmWasm/wasmvm/v2/internal/runtime/crypto" + "github.com/CosmWasm/wasmvm/v2/internal/runtime/memory" "github.com/CosmWasm/wasmvm/v2/types" ) @@ -34,6 +36,7 @@ type WazeroRuntime struct { kvStore types.KVStore api *types.GoAPI querier types.Querier + module api.Module } type RuntimeEnvironment struct { @@ -51,6 +54,10 @@ type RuntimeEnvironment struct { iterators map[uint64]map[uint64]types.Iterator nextIterID uint64 nextCallID uint64 + + // Host functions + Crypto *crypto.HostFunctions + Memory *memory.HostFunctions } // Constants for memory management @@ -1469,3 +1476,10 @@ func (w *WazeroRuntime) SimulateStoreCode(code []byte) ([]byte, error, bool) { // Return checksum, no error, and persisted=false return checksum[:], nil, false } + +// StartContract starts a new contract execution +func (r *WazeroRuntime) StartContract(db types.KVStore, api *types.GoAPI, querier types.Querier, gasLimit uint64) (*RuntimeEnvironment, error) { + env := NewRuntimeEnvironment(db, api, querier) + env.gasLimit = gasLimit + return env, nil +}