Skip to content

Commit

Permalink
feat: add TLS support
Browse files Browse the repository at this point in the history
This adds TLS support. An abstraction for the redis transport `RedisTransport` has been added. The existing TCP transport has been moved in as an implementation `TCPTransport <: RedisTransport`. A new TLS transport has been added as `TLSTransport <: RedisTransport`.

This can in future be extended to support unix sockets too (ref: #84)

fixes: #87
  • Loading branch information
tanmaykm committed Dec 28, 2023
1 parent fcd0353 commit a80f2ad
Show file tree
Hide file tree
Showing 12 changed files with 650 additions and 468 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ version = "2.0.0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
MbedTLS = "739be429-bea8-5141-9913-cc70e7f3736d"

[compat]
julia = "^1"
DataStructures = "^0.18"
MbedTLS = "0.6.8, 0.7, 1"

[extras]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
2 changes: 2 additions & 0 deletions src/Redis.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Redis
using Dates
using Sockets
using MbedTLS

import Base.get, Base.keys, Base.time

Expand Down Expand Up @@ -59,6 +60,7 @@ export sentinel_masters, sentinel_master, sentinel_slaves, sentinel_getmasteradd
export REDIS_PERSISTENT_KEY, REDIS_EXPIRED_KEY

include("exceptions.jl")
include("transport/transport.jl")
include("connection.jl")
include("parser.jl")
include("client.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ end
function subscription_loop(conn::SubscriptionConnection, err_callback::Function)
while is_connected(conn)
try
l = getline(conn.socket)
reply = parseline(l, conn.socket)
l = getline(conn.transport)
reply = parseline(l, conn.transport)
reply = convert_reply(reply)
message = SubscriptionMessage(reply)
if message.message_type == SubscriptionMessageType.Message
Expand Down
85 changes: 53 additions & 32 deletions src/connection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import Sockets.connect, Sockets.TCPSocket, Base.StatusActive, Base.StatusOpen, Base.StatusPaused

abstract type RedisConnectionBase end
abstract type SubscribableConnection<:RedisConnectionBase end

Expand All @@ -8,31 +6,31 @@ struct RedisConnection <: SubscribableConnection
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
end

struct SentinelConnection <: SubscribableConnection
host::AbstractString
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
end

struct TransactionConnection <: RedisConnectionBase
host::AbstractString
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
end

mutable struct PipelineConnection <: RedisConnectionBase
host::AbstractString
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
num_commands::Integer
end

Expand All @@ -43,77 +41,100 @@ struct SubscriptionConnection <: RedisConnectionBase
db::Integer
callbacks::Dict{AbstractString, Function}
pcallbacks::Dict{AbstractString, Function}
socket::TCPSocket
transport::Transport.RedisTransport
end

function RedisConnection(; host="127.0.0.1", port=6379, password="", db=0)
Transport.get_sslconfig(s::RedisConnectionBase) = Transport.get_sslconfig(s.transport)

function RedisConnection(; host="127.0.0.1", port=6379, password="", db=0, sslconfig=nothing)
try
socket = connect(host, port)
connection = RedisConnection(host, port, password, db, socket)
connection = RedisConnection(
host,
port,
password,
db,
Transport.transport(host, port, sslconfig)
)
on_connect(connection)
catch
throw(ConnectionException("Failed to connect to Redis server"))
end
end

function SentinelConnection(; host="127.0.0.1", port=26379, password="", db=0)
function SentinelConnection(; host="127.0.0.1", port=26379, password="", db=0, sslconfig=nothing)

Check warning on line 64 in src/connection.jl

View check run for this annotation

Codecov / codecov/patch

src/connection.jl#L64

Added line #L64 was not covered by tests
try
socket = connect(host, port)
sentinel_connection = SentinelConnection(host, port, password, db, socket)
sentinel_connection = SentinelConnection(

Check warning on line 66 in src/connection.jl

View check run for this annotation

Codecov / codecov/patch

src/connection.jl#L66

Added line #L66 was not covered by tests
host,
port,
password,
db,
Transport.transport(host, port, sslconfig)
)
on_connect(sentinel_connection)
catch
throw(ConnectionException("Failed to connect to Redis sentinel"))
end
end

function TransactionConnection(parent::RedisConnection)
function TransactionConnection(parent::RedisConnection; sslconfig=Transport.get_sslconfig(parent))
try
socket = connect(parent.host, parent.port)
transaction_connection = TransactionConnection(parent.host,
parent.port, parent.password, parent.db, socket)
transaction_connection = TransactionConnection(
parent.host,
parent.port,
parent.password,
parent.db,
Transport.transport(parent.host, parent.port, sslconfig)
)
on_connect(transaction_connection)
catch
throw(ConnectionException("Failed to create transaction"))
end
end

function PipelineConnection(parent::RedisConnection)
function PipelineConnection(parent::RedisConnection; sslconfig=Transport.get_sslconfig(parent))
try
socket = connect(parent.host, parent.port)
pipeline_connection = PipelineConnection(parent.host,
parent.port, parent.password, parent.db, socket, 0)
pipeline_connection = PipelineConnection(
parent.host,
parent.port,
parent.password,
parent.db,
Transport.transport(parent.host, parent.port, sslconfig),
0
)
on_connect(pipeline_connection)
catch
throw(ConnectionException("Failed to create pipeline"))
end
end

function SubscriptionConnection(parent::SubscribableConnection)
function SubscriptionConnection(parent::SubscribableConnection; sslconfig=Transport.get_sslconfig(parent))
try
socket = connect(parent.host, parent.port)
subscription_connection = SubscriptionConnection(parent.host,
parent.port, parent.password, parent.db, Dict{AbstractString, Function}(),
Dict{AbstractString, Function}(), socket)
subscription_connection = SubscriptionConnection(
parent.host,
parent.port,
parent.password,
parent.db,
Dict{AbstractString, Function}(),
Dict{AbstractString, Function}(),
Transport.transport(parent.host, parent.port, sslconfig)
)
on_connect(subscription_connection)
catch
throw(ConnectionException("Failed to create subscription"))
end
end

function on_connect(conn::RedisConnectionBase)
# disable nagle and enable quickack to speed up the usually small exchanges
Sockets.nagle(conn.socket, false)
Sockets.quickack(conn.socket, true)

Transport.set_props!(conn.transport)
conn.password != "" && auth(conn, conn.password)
conn.db != 0 && select(conn, conn.db)
conn
end

function disconnect(conn::RedisConnectionBase)
close(conn.socket)
Transport.close(conn.transport)
end

function is_connected(conn::RedisConnectionBase)
conn.socket.status == StatusActive || conn.socket.status == StatusOpen || conn.socket.status == StatusPaused
Transport.is_connected(conn.transport)
end
28 changes: 14 additions & 14 deletions src/parser.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Formatting of incoming Redis Replies
"""
function getline(s::TCPSocket)
l = chomp(readline(s))
function getline(t::Transport.RedisTransport)
l = chomp(Transport.read_line(t))
length(l) > 1 || throw(ProtocolException("Invalid response received: $l"))
return l
end
Expand All @@ -12,15 +12,15 @@ convert_reply(reply::Array) = [convert_reply(r) for r in reply]
convert_reply(x) = x

function read_reply(conn::RedisConnectionBase)
l = getline(conn.socket)
reply = parseline(l, conn.socket)
l = getline(conn.transport)
reply = parseline(l, conn.transport)
convert_reply(reply)
end

parse_error(l::AbstractString) = throw(ServerException(l))

function parse_bulk_string(s::TCPSocket, slen::Int)
b = read(s, slen+2) # add crlf
function parse_bulk_string(t::Transport.RedisTransport, slen::Int)
b = Transport.read_nbytes(t, slen+2) # add crlf
if length(b) != slen + 2
throw(ProtocolException(
"Bulk string read error: expected $slen bytes; received $(length(b))"
Expand All @@ -30,17 +30,17 @@ function parse_bulk_string(s::TCPSocket, slen::Int)
end
end

function parse_array(s::TCPSocket, slen::Int)
function parse_array(t::Transport.RedisTransport, slen::Int)
a = Array{Any, 1}(undef, slen)
for i = 1:slen
l = getline(s)
r = parseline(l, s)
l = getline(t)
r = parseline(l, t)
a[i] = r
end
return a
end

function parseline(l::AbstractString, s::TCPSocket)
function parseline(l::AbstractString, t::Transport.RedisTransport)
reply_type = l[1]
reply_token = l[2:end]
if reply_type == '+'
Expand All @@ -52,14 +52,14 @@ function parseline(l::AbstractString, s::TCPSocket)
if slen == -1
nothing
else
parse_bulk_string(s, slen)
parse_bulk_string(t, slen)
end
elseif reply_type == '*'
slen = parse(Int, reply_token)
if slen == -1
nothing
else
parse_array(s, slen)
parse_array(t, slen)
end
elseif reply_type == '-'
parse_error(reply_token)
Expand Down Expand Up @@ -90,8 +90,8 @@ function execute_command_without_reply(conn::RedisConnectionBase, command)
is_connected(conn) || throw(ConnectionException("Socket is disconnected"))
iob = IOBuffer()
pack_command(iob, command)
lock(conn.socket.lock) do
write(conn.socket, take!(iob))
Transport.io_lock(conn.transport) do
Transport.write_bytes(conn.transport, take!(iob))
end
end

Expand Down
19 changes: 19 additions & 0 deletions src/transport/tcp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
struct TCPTransport <: RedisTransport
sock::TCPSocket
end

read_line(t::TCPTransport) = readline(t.sock)
read_nbytes(t::TCPTransport, m::Int) = read(t.sock, m)
write_bytes(t::TCPTransport, b::Vector{UInt8}) = write(t.sock, b)
Base.close(t::TCPTransport) = close(t.sock)
function set_props!(t::TCPTransport)
# disable nagle and enable quickack to speed up the usually small exchanges
Sockets.nagle(t.sock, false)
Sockets.quickack(t.sock, true)
end
get_sslconfig(::TCPTransport) = nothing
io_lock(f, t::TCPTransport) = lock(f, t.sock.lock)
function is_connected(t::TCPTransport)
status = t.sock.status
status == StatusActive || status == StatusOpen || status == StatusPaused
end
56 changes: 56 additions & 0 deletions src/transport/tls.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
struct TLSTransport <: RedisTransport
sock::TCPSocket
ctx::MbedTLS.SSLContext
sslconfig::MbedTLS.SSLConfig
buff::IOBuffer

function TLSTransport(sock::TCPSocket, sslconfig::MbedTLS.SSLConfig)
ctx = MbedTLS.SSLContext()
MbedTLS.setup!(ctx, sslconfig)
MbedTLS.associate!(ctx, sock)
MbedTLS.handshake(ctx)

Check warning on line 11 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L7-L11

Added lines #L7 - L11 were not covered by tests

return new(sock, ctx, sslconfig, PipeBuffer())

Check warning on line 13 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L13

Added line #L13 was not covered by tests
end
end

function read_into_buffer_until(cond::Function, t::TLSTransport)
cond(t) && return

Check warning on line 18 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L17-L18

Added lines #L17 - L18 were not covered by tests

buff = Vector{UInt8}(undef, MbedTLS.MBEDTLS_SSL_MAX_CONTENT_LEN)
pbuff = pointer(buff)

Check warning on line 21 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L20-L21

Added lines #L20 - L21 were not covered by tests

while !cond(t) && !eof(t.ctx)
nread = readbytes!(t.ctx, buff; all=false)
if nread > 0
unsafe_write(t.buff, pbuff, nread)

Check warning on line 26 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L23-L26

Added lines #L23 - L26 were not covered by tests
end
end

Check warning on line 28 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L28

Added line #L28 was not covered by tests
end

function read_line(t::TLSTransport)
read_into_buffer_until(t) do t
iob = t.buff
(bytesavailable(t.buff) > 0) && (UInt8('\n') in view(iob.data, iob.ptr:iob.size))

Check warning on line 34 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L31-L34

Added lines #L31 - L34 were not covered by tests
end
return readline(t.buff)

Check warning on line 36 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L36

Added line #L36 was not covered by tests
end
function read_nbytes(t::TLSTransport, m::Int)
read_into_buffer_until(t) do t
bytesavailable(t.buff) >= m

Check warning on line 40 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L38-L40

Added lines #L38 - L40 were not covered by tests
end
return read(t.buff, m)

Check warning on line 42 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L42

Added line #L42 was not covered by tests
end
write_bytes(t::TLSTransport, b::Vector{UInt8}) = write(t.ctx, b)
Base.close(t::TLSTransport) = close(t.ctx)
function set_props!(s::TLSTransport)

Check warning on line 46 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L44-L46

Added lines #L44 - L46 were not covered by tests
# disable nagle and enable quickack to speed up the usually small exchanges
Sockets.nagle(s.sock, false)
Sockets.quickack(s.sock, true)

Check warning on line 49 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L48-L49

Added lines #L48 - L49 were not covered by tests
end
get_sslconfig(t::TLSTransport) = t.sslconfig
io_lock(f, t::TLSTransport) = lock(f, t.sock.lock)
function is_connected(t::TLSTransport)
status = t.sock.status
status == StatusActive || status == StatusOpen || status == StatusPaused

Check warning on line 55 in src/transport/tls.jl

View check run for this annotation

Codecov / codecov/patch

src/transport/tls.jl#L51-L55

Added lines #L51 - L55 were not covered by tests
end
Loading

0 comments on commit a80f2ad

Please sign in to comment.