From 68f2e370f1bdb63615830fb55582d71ee5abf1e0 Mon Sep 17 00:00:00 2001 From: zachmann Date: Mon, 28 Sep 2020 13:41:34 +0200 Subject: [PATCH] add encryption to communication --- liboidcagent/comm.go | 178 +++++++++++++++++++++++++++++++++++ liboidcagent/liboidcagent.go | 90 +++++++----------- 2 files changed, 213 insertions(+), 55 deletions(-) create mode 100644 liboidcagent/comm.go diff --git a/liboidcagent/comm.go b/liboidcagent/comm.go new file mode 100644 index 0000000..5b4582f --- /dev/null +++ b/liboidcagent/comm.go @@ -0,0 +1,178 @@ +package liboidcagent + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "os" + "strings" + + "golang.org/x/crypto/nacl/box" +) + +func communicateWithSock(c net.Conn, request string) (response []byte, err error) { + _, err = c.Write([]byte(request)) + if err != nil { + err = fmt.Errorf("Writing to socket: %s", err) + return + } + for { + buffer := make([]byte, 4096) + n, e := c.Read(buffer) + response = append(response, buffer[:n]...) + if n < 4096 || e != nil { + err = e + break + } + } + if err != nil { + err = fmt.Errorf("Reading from socket: %s", err) + } + return +} + +func initCommunication(remote bool) (c net.Conn, err error) { + envVar := "OIDC_SOCK" + sockType := "unix" + if remote { + envVar = "OIDC_REMOTE_SOCK" + sockType = "tcp" + } + socketValue, socketSet := os.LookupEnv(envVar) + if !socketSet { + err = fmt.Errorf("$%s not set", envVar) + return + } + + if remote { + if _, port, _ := net.SplitHostPort(socketValue); port == "" { + socketValue = fmt.Sprintf("%s:%d", socketValue, 42424) + } + } + + c, err = net.Dial(sockType, socketValue) + if err != nil { + err = fmt.Errorf("Dialing socket: %s", err) + return + } + return +} + +func communicatePlain(remote bool, request string) (response string, err error) { + c, err := initCommunication(remote) + if err != nil { + return + } + defer c.Close() + + res, err := communicateWithSock(c, request) + response = string(res) + return +} + +func communicateEncrypted(remote bool, request string) (response string, err error) { + c, err := initCommunication(remote) + if err != nil { + return + } + defer c.Close() + + clientPrivateKey, _, serverPublicKey, err := initKeys(c) + if err != nil { + return + } + + encryptedMsg, err := encryptMessage(request, serverPublicKey, clientPrivateKey) + if err != nil { + return + } + + encryptedResponse, err := communicateWithSock(c, encryptedMsg) + if err != nil { + return + } + encryptedResponseStr := string(encryptedResponse) + if isJSON(encryptedResponseStr) { + // response not encrypted + response = encryptedResponseStr + return + } + + response, err = decryptMessage(encryptedResponseStr, serverPublicKey, clientPrivateKey) + return +} + +func initKeys(c net.Conn) (clientPrivateKey, clientPublicKey, serverPublicKey *[32]byte, err error) { + clientPublicKey, clientPrivateKey, err = box.GenerateKey(rand.Reader) + if err != nil { + return + } + clientPubKeyBase64 := base64.StdEncoding.EncodeToString(clientPublicKey[:]) + serverPubKeyBase64, err := communicateWithSock(c, clientPubKeyBase64) + if err != nil { + return + } + serverPubKeyB, err := decodeBytes(serverPubKeyBase64) + if err != nil { + return + } + serverPublicKey = sliceToArray32(serverPubKeyB) + return +} + +func encryptMessage(message string, serverPublicKey, clientPrivateKey *[32]byte) (string, error) { + var nonce [24]byte + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { + return "", err + } + encrypted := box.Seal([]byte{}, []byte(message), &nonce, serverPublicKey, clientPrivateKey) + encryptedBase64 := base64.StdEncoding.EncodeToString(encrypted) + nonceBase64 := base64.StdEncoding.EncodeToString(nonce[:]) + msgLen := len(message) + encryptedMsg := fmt.Sprintf("%d:%s:%s", msgLen, nonceBase64, encryptedBase64) + return encryptedMsg, nil +} + +func decryptMessage(message string, serverPublicKey, clientPrivateKey *[32]byte) (decrypted string, err error) { + split := strings.Split(message, ":") + nonce, err := base64.StdEncoding.DecodeString(split[1]) + if err != nil { + return + } + encryptedRes, err := base64.StdEncoding.DecodeString(split[2]) + if err != nil { + return + } + res, ok := box.Open([]byte{}, encryptedRes, sliceToArray24(nonce), serverPublicKey, clientPrivateKey) + decrypted = string(res) + if !ok { + err = fmt.Errorf("Decryption error") + } + return +} + +func sliceToArray32(slice []byte) *[32]byte { + arr := [32]byte{} + copy(arr[:], slice) + return &arr +} + +func sliceToArray24(slice []byte) *[24]byte { + arr := [24]byte{} + copy(arr[:], slice) + return &arr +} + +func decodeBytes(src []byte) ([]byte, error) { + out := make([]byte, base64.StdEncoding.DecodedLen(len(src))) + n, err := base64.StdEncoding.Decode(out, src) + return out[:n], err +} + +func isJSON(s string) bool { + var js map[string]interface{} + return json.Unmarshal([]byte(s), &js) == nil +} diff --git a/liboidcagent/liboidcagent.go b/liboidcagent/liboidcagent.go index cd2f8e5..b1e6047 100644 --- a/liboidcagent/liboidcagent.go +++ b/liboidcagent/liboidcagent.go @@ -2,11 +2,7 @@ package liboidcagent import ( "encoding/json" - "errors" "fmt" - "io/ioutil" - "net" - "os" "time" ) @@ -52,33 +48,11 @@ func createTokenRequestIssuer(issuer string, minValidPeriod uint64, scope string return createTokenRequest(requestPartIss, minValidPeriod, scope, applicationHint, audience) } -func communicateWithSock(request string) (response []byte, err error) { - socketValue, socketSet := os.LookupEnv("OIDC_SOCK") - if !socketSet { - err = errors.New("$OIDC_SOCK not set") - return +func parseIpcResponse(remote bool, response []byte) (tokenResponse TokenResponse, err error) { + rem := "" + if remote { + rem = "Remote " } - - c, err := net.Dial("unix", socketValue) - if err != nil { - err = fmt.Errorf("Dialing socket: %s", err) - return - } - defer c.Close() - - _, err = c.Write([]byte(request)) - if err != nil { - err = fmt.Errorf("Writing to socket: %s", err) - return - } - response, err = ioutil.ReadAll(c) - if err != nil { - err = fmt.Errorf("Reading from socket: %s", err) - } - return -} - -func parseIpcResponse(response []byte) (tokenResponse TokenResponse, err error) { var res rawTokenResponse err = json.Unmarshal(response, &res) if err != nil { @@ -86,11 +60,11 @@ func parseIpcResponse(response []byte) (tokenResponse TokenResponse, err error) return } if res.Error != "" { - err = fmt.Errorf("Agent error: %s", res.Error) + err = fmt.Errorf("%sAgent error: %s", rem, res.Error) return } if res.Status == "failure" { - err = fmt.Errorf("status is \"failure\"") + err = fmt.Errorf("%sstatus is \"failure\"", rem) return } tokenResponse = TokenResponse{ @@ -106,23 +80,38 @@ func parseIpcResponse(response []byte) (tokenResponse TokenResponse, err error) // Deprecated: GetTokenResponse is deprecated and only exists for compatibility // reasons. New applications should use GetTokenResponse2 instead. func GetTokenResponse(accountname string, minValidPeriod uint64, scope, applicationHint string) (resp TokenResponse, err error) { - ipcReq := createTokenRequestAccount(accountname, minValidPeriod, scope, applicationHint, "") - ipcResponse, err := communicateWithSock(ipcReq) - if err != nil { - return - } - resp, err = parseIpcResponse(ipcResponse) - return + return GetTokenResponse2(accountname, minValidPeriod, scope, applicationHint, "") } // GetTokenResponse2 gets a token response by accountname func GetTokenResponse2(accountname string, minValidPeriod uint64, scope, applicationHint, audience string) (resp TokenResponse, err error) { ipcReq := createTokenRequestAccount(accountname, minValidPeriod, scope, applicationHint, audience) - ipcResponse, err := communicateWithSock(ipcReq) + ipcResponse, err := communicateEncrypted(false, ipcReq) if err != nil { + if err.Error() != "$OIDC_SOCK not set" { + return + } + ipcResponse, err = communicateEncrypted(true, ipcReq) + if err != nil { + err = fmt.Errorf("$OIDC_SOCK not set and %s on remote", err) + return + } + resp, err = parseIpcResponse(true, []byte(ipcResponse)) return } - resp, err = parseIpcResponse(ipcResponse) + resp, err = parseIpcResponse(false, []byte(ipcResponse)) + if err != nil && err.Error() == "Agent error: No account configured with that short name" { + localErr := err + //Try remote + ipcResponse, err = communicateEncrypted(true, ipcReq) + if err != nil { + if err.Error() == "$OIDC_REMOTE_SOCK not set" { + err = localErr + } + return + } + resp, err = parseIpcResponse(true, []byte(ipcResponse)) + } return } @@ -132,25 +121,18 @@ func GetTokenResponse2(accountname string, minValidPeriod uint64, scope, applica // compatibility reasons. New applications should use // GetTokenResponseByIssuerURL2 instead. func GetTokenResponseByIssuerURL(issuer string, minValidPeriod uint64, scope, applicationHint string) (tokenResponse TokenResponse, err error) { - ipcReq := createTokenRequestIssuer(issuer, minValidPeriod, scope, applicationHint, "") - response, err := communicateWithSock(ipcReq) - if err != nil { - err = fmt.Errorf("Communicating with socket: %s", err) - return - } - tokenResponse, err = parseIpcResponse(response) - return + return GetTokenResponseByIssuerURL2(issuer, minValidPeriod, scope, applicationHint, "") } // GetTokenResponseByIssuerURL2 gets a token response by issuerURL func GetTokenResponseByIssuerURL2(issuer string, minValidPeriod uint64, scope, applicationHint, audience string) (tokenResponse TokenResponse, err error) { ipcReq := createTokenRequestIssuer(issuer, minValidPeriod, scope, applicationHint, audience) - response, err := communicateWithSock(ipcReq) + response, err := communicateEncrypted(false, ipcReq) if err != nil { err = fmt.Errorf("Communicating with socket: %s", err) return } - tokenResponse, err = parseIpcResponse(response) + tokenResponse, err = parseIpcResponse(false, []byte(response)) return } @@ -159,8 +141,7 @@ func GetTokenResponseByIssuerURL2(issuer string, minValidPeriod uint64, scope, a // Deprecated: GetAccessToken is deprecated and only exists for compatibility // reasons. New applications should use GetAccessToken2 instead. func GetAccessToken(accountname string, minValidPeriod uint64, scope, applicationHint string) (token string, err error) { - tokenResponse, err := GetTokenResponse(accountname, minValidPeriod, scope, applicationHint) - return tokenResponse.Token, err + return GetAccessToken2(accountname, minValidPeriod, scope, applicationHint, "") } // GetAccessToken2 gets an access token by accountname @@ -174,8 +155,7 @@ func GetAccessToken2(accountname string, minValidPeriod uint64, scope, applicati // Deprecated: GetAccessTokenByIssuerURL is deprecated and only exists for compatibility // reasons. New applications should use GetAccessTokenByIssuerURL2 instead. func GetAccessTokenByIssuerURL(issuerURL string, minValidPeriod uint64, scope, applicationHint string) (token string, err error) { - tokenResponse, err := GetTokenResponseByIssuerURL(issuerURL, minValidPeriod, scope, applicationHint) - return tokenResponse.Token, err + return GetAccessTokenByIssuerURL2(issuerURL, minValidPeriod, scope, applicationHint, "") } // GetAccessTokenByIssuerURL2 gets an access token by issuerURL