Skip to content

Commit

Permalink
add encryption to communication
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmann committed Sep 29, 2020
1 parent 1eccc83 commit 68f2e37
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 55 deletions.
178 changes: 178 additions & 0 deletions liboidcagent/comm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package liboidcagent

import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"os"
"strings"

"golang.org/x/crypto/nacl/box"
)

func communicateWithSock(c net.Conn, request string) (response []byte, err error) {
_, err = c.Write([]byte(request))
if err != nil {
err = fmt.Errorf("Writing to socket: %s", err)
return
}
for {
buffer := make([]byte, 4096)
n, e := c.Read(buffer)
response = append(response, buffer[:n]...)
if n < 4096 || e != nil {
err = e
break
}
}
if err != nil {
err = fmt.Errorf("Reading from socket: %s", err)
}
return
}

func initCommunication(remote bool) (c net.Conn, err error) {
envVar := "OIDC_SOCK"
sockType := "unix"
if remote {
envVar = "OIDC_REMOTE_SOCK"
sockType = "tcp"
}
socketValue, socketSet := os.LookupEnv(envVar)
if !socketSet {
err = fmt.Errorf("$%s not set", envVar)
return
}

if remote {
if _, port, _ := net.SplitHostPort(socketValue); port == "" {
socketValue = fmt.Sprintf("%s:%d", socketValue, 42424)
}
}

c, err = net.Dial(sockType, socketValue)
if err != nil {
err = fmt.Errorf("Dialing socket: %s", err)
return
}
return
}

func communicatePlain(remote bool, request string) (response string, err error) {
c, err := initCommunication(remote)
if err != nil {
return
}
defer c.Close()

res, err := communicateWithSock(c, request)
response = string(res)
return
}

func communicateEncrypted(remote bool, request string) (response string, err error) {
c, err := initCommunication(remote)
if err != nil {
return
}
defer c.Close()

clientPrivateKey, _, serverPublicKey, err := initKeys(c)
if err != nil {
return
}

encryptedMsg, err := encryptMessage(request, serverPublicKey, clientPrivateKey)
if err != nil {
return
}

encryptedResponse, err := communicateWithSock(c, encryptedMsg)
if err != nil {
return
}
encryptedResponseStr := string(encryptedResponse)
if isJSON(encryptedResponseStr) {
// response not encrypted
response = encryptedResponseStr
return
}

response, err = decryptMessage(encryptedResponseStr, serverPublicKey, clientPrivateKey)
return
}

func initKeys(c net.Conn) (clientPrivateKey, clientPublicKey, serverPublicKey *[32]byte, err error) {
clientPublicKey, clientPrivateKey, err = box.GenerateKey(rand.Reader)
if err != nil {
return
}
clientPubKeyBase64 := base64.StdEncoding.EncodeToString(clientPublicKey[:])
serverPubKeyBase64, err := communicateWithSock(c, clientPubKeyBase64)
if err != nil {
return
}
serverPubKeyB, err := decodeBytes(serverPubKeyBase64)
if err != nil {
return
}
serverPublicKey = sliceToArray32(serverPubKeyB)
return
}

func encryptMessage(message string, serverPublicKey, clientPrivateKey *[32]byte) (string, error) {
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
return "", err
}
encrypted := box.Seal([]byte{}, []byte(message), &nonce, serverPublicKey, clientPrivateKey)
encryptedBase64 := base64.StdEncoding.EncodeToString(encrypted)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce[:])
msgLen := len(message)
encryptedMsg := fmt.Sprintf("%d:%s:%s", msgLen, nonceBase64, encryptedBase64)
return encryptedMsg, nil
}

func decryptMessage(message string, serverPublicKey, clientPrivateKey *[32]byte) (decrypted string, err error) {
split := strings.Split(message, ":")
nonce, err := base64.StdEncoding.DecodeString(split[1])
if err != nil {
return
}
encryptedRes, err := base64.StdEncoding.DecodeString(split[2])
if err != nil {
return
}
res, ok := box.Open([]byte{}, encryptedRes, sliceToArray24(nonce), serverPublicKey, clientPrivateKey)
decrypted = string(res)
if !ok {
err = fmt.Errorf("Decryption error")
}
return
}

func sliceToArray32(slice []byte) *[32]byte {
arr := [32]byte{}
copy(arr[:], slice)
return &arr
}

func sliceToArray24(slice []byte) *[24]byte {
arr := [24]byte{}
copy(arr[:], slice)
return &arr
}

func decodeBytes(src []byte) ([]byte, error) {
out := make([]byte, base64.StdEncoding.DecodedLen(len(src)))
n, err := base64.StdEncoding.Decode(out, src)
return out[:n], err
}

func isJSON(s string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(s), &js) == nil
}
90 changes: 35 additions & 55 deletions liboidcagent/liboidcagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@ package liboidcagent

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"time"
)

Expand Down Expand Up @@ -52,45 +48,23 @@ func createTokenRequestIssuer(issuer string, minValidPeriod uint64, scope string
return createTokenRequest(requestPartIss, minValidPeriod, scope, applicationHint, audience)
}

func communicateWithSock(request string) (response []byte, err error) {
socketValue, socketSet := os.LookupEnv("OIDC_SOCK")
if !socketSet {
err = errors.New("$OIDC_SOCK not set")
return
func parseIpcResponse(remote bool, response []byte) (tokenResponse TokenResponse, err error) {
rem := ""
if remote {
rem = "Remote "
}

c, err := net.Dial("unix", socketValue)
if err != nil {
err = fmt.Errorf("Dialing socket: %s", err)
return
}
defer c.Close()

_, err = c.Write([]byte(request))
if err != nil {
err = fmt.Errorf("Writing to socket: %s", err)
return
}
response, err = ioutil.ReadAll(c)
if err != nil {
err = fmt.Errorf("Reading from socket: %s", err)
}
return
}

func parseIpcResponse(response []byte) (tokenResponse TokenResponse, err error) {
var res rawTokenResponse
err = json.Unmarshal(response, &res)
if err != nil {
err = fmt.Errorf("Unable to unmarshal: %s", response)
return
}
if res.Error != "" {
err = fmt.Errorf("Agent error: %s", res.Error)
err = fmt.Errorf("%sAgent error: %s", rem, res.Error)
return
}
if res.Status == "failure" {
err = fmt.Errorf("status is \"failure\"")
err = fmt.Errorf("%sstatus is \"failure\"", rem)
return
}
tokenResponse = TokenResponse{
Expand All @@ -106,23 +80,38 @@ func parseIpcResponse(response []byte) (tokenResponse TokenResponse, err error)
// Deprecated: GetTokenResponse is deprecated and only exists for compatibility
// reasons. New applications should use GetTokenResponse2 instead.
func GetTokenResponse(accountname string, minValidPeriod uint64, scope, applicationHint string) (resp TokenResponse, err error) {
ipcReq := createTokenRequestAccount(accountname, minValidPeriod, scope, applicationHint, "")
ipcResponse, err := communicateWithSock(ipcReq)
if err != nil {
return
}
resp, err = parseIpcResponse(ipcResponse)
return
return GetTokenResponse2(accountname, minValidPeriod, scope, applicationHint, "")
}

// GetTokenResponse2 gets a token response by accountname
func GetTokenResponse2(accountname string, minValidPeriod uint64, scope, applicationHint, audience string) (resp TokenResponse, err error) {
ipcReq := createTokenRequestAccount(accountname, minValidPeriod, scope, applicationHint, audience)
ipcResponse, err := communicateWithSock(ipcReq)
ipcResponse, err := communicateEncrypted(false, ipcReq)
if err != nil {
if err.Error() != "$OIDC_SOCK not set" {
return
}
ipcResponse, err = communicateEncrypted(true, ipcReq)
if err != nil {
err = fmt.Errorf("$OIDC_SOCK not set and %s on remote", err)
return
}
resp, err = parseIpcResponse(true, []byte(ipcResponse))
return
}
resp, err = parseIpcResponse(ipcResponse)
resp, err = parseIpcResponse(false, []byte(ipcResponse))
if err != nil && err.Error() == "Agent error: No account configured with that short name" {
localErr := err
//Try remote
ipcResponse, err = communicateEncrypted(true, ipcReq)
if err != nil {
if err.Error() == "$OIDC_REMOTE_SOCK not set" {
err = localErr
}
return
}
resp, err = parseIpcResponse(true, []byte(ipcResponse))
}
return
}

Expand All @@ -132,25 +121,18 @@ func GetTokenResponse2(accountname string, minValidPeriod uint64, scope, applica
// compatibility reasons. New applications should use
// GetTokenResponseByIssuerURL2 instead.
func GetTokenResponseByIssuerURL(issuer string, minValidPeriod uint64, scope, applicationHint string) (tokenResponse TokenResponse, err error) {
ipcReq := createTokenRequestIssuer(issuer, minValidPeriod, scope, applicationHint, "")
response, err := communicateWithSock(ipcReq)
if err != nil {
err = fmt.Errorf("Communicating with socket: %s", err)
return
}
tokenResponse, err = parseIpcResponse(response)
return
return GetTokenResponseByIssuerURL2(issuer, minValidPeriod, scope, applicationHint, "")
}

// GetTokenResponseByIssuerURL2 gets a token response by issuerURL
func GetTokenResponseByIssuerURL2(issuer string, minValidPeriod uint64, scope, applicationHint, audience string) (tokenResponse TokenResponse, err error) {
ipcReq := createTokenRequestIssuer(issuer, minValidPeriod, scope, applicationHint, audience)
response, err := communicateWithSock(ipcReq)
response, err := communicateEncrypted(false, ipcReq)
if err != nil {
err = fmt.Errorf("Communicating with socket: %s", err)
return
}
tokenResponse, err = parseIpcResponse(response)
tokenResponse, err = parseIpcResponse(false, []byte(response))
return
}

Expand All @@ -159,8 +141,7 @@ func GetTokenResponseByIssuerURL2(issuer string, minValidPeriod uint64, scope, a
// Deprecated: GetAccessToken is deprecated and only exists for compatibility
// reasons. New applications should use GetAccessToken2 instead.
func GetAccessToken(accountname string, minValidPeriod uint64, scope, applicationHint string) (token string, err error) {
tokenResponse, err := GetTokenResponse(accountname, minValidPeriod, scope, applicationHint)
return tokenResponse.Token, err
return GetAccessToken2(accountname, minValidPeriod, scope, applicationHint, "")
}

// GetAccessToken2 gets an access token by accountname
Expand All @@ -174,8 +155,7 @@ func GetAccessToken2(accountname string, minValidPeriod uint64, scope, applicati
// Deprecated: GetAccessTokenByIssuerURL is deprecated and only exists for compatibility
// reasons. New applications should use GetAccessTokenByIssuerURL2 instead.
func GetAccessTokenByIssuerURL(issuerURL string, minValidPeriod uint64, scope, applicationHint string) (token string, err error) {
tokenResponse, err := GetTokenResponseByIssuerURL(issuerURL, minValidPeriod, scope, applicationHint)
return tokenResponse.Token, err
return GetAccessTokenByIssuerURL2(issuerURL, minValidPeriod, scope, applicationHint, "")
}

// GetAccessTokenByIssuerURL2 gets an access token by issuerURL
Expand Down

0 comments on commit 68f2e37

Please sign in to comment.