diff --git a/config/webhook.go b/config/webhook.go index fdd4d29fe..6f9a8ac09 100644 --- a/config/webhook.go +++ b/config/webhook.go @@ -80,7 +80,7 @@ func ExecWebhook(domains *Domains, conf *Config) (v4Status updateStatusType, v6S return } - headers := checkParseHeaders(conf.WebhookHeaders) + headers := extractHeaders(conf.WebhookHeaders) for key, value := range headers { req.Header.Add(key, value) } @@ -144,19 +144,28 @@ func getDomainsStr(domains []*Domain) string { return str } -func checkParseHeaders(headerStr string) (headers map[string]string) { - headers = make(map[string]string) - headerArr := strings.Split(headerStr, "\r\n") - for _, headerStr := range headerArr { - headerStr = strings.TrimSpace(headerStr) - if headerStr != "" { - parts := strings.Split(headerStr, ":") - if len(parts) != 2 { - util.Log("Webhook Header不正确: %s", headerStr) - continue - } - headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) +// extractHeaders converts s into a map of headers. +// +// See also: https://github.com/appleboy/gorush/blob/v1.17.0/notify/feedback.go#L15 +func extractHeaders(s string) map[string]string { + lines := util.SplitLines(s) + headers := make(map[string]string, len(lines)) + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + parts := strings.Split(line, ":") + if len(parts) != 2 { + util.Log("Webhook Header不正确: %s", line) + continue } + + k, v := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + headers[k] = v } + return headers } diff --git a/config/webhook_test.go b/config/webhook_test.go index 0ae467ebd..df77bb723 100644 --- a/config/webhook_test.go +++ b/config/webhook_test.go @@ -1,17 +1,22 @@ package config import ( - "fmt" + "reflect" "testing" ) // TestParseHeaderArr 测试 parseHeaderArr -func TestParseHeaderArr(t *testing.T) { - headers := "a : 1\r\nb:2\r\n" - expected := `map[a:1 b:2]` - parsedHeaders := checkParseHeaders(headers) - resultStr := fmt.Sprintf("%v", parsedHeaders) - if resultStr != expected { - t.Error("解析Header失败", resultStr) +func TestExtractHeaders(t *testing.T) { + input := ` +a: foo +b: bar` + expected := map[string]string{ + "a": "foo", + "b": "bar", + } + + parsedHeaders := extractHeaders(input) + if !reflect.DeepEqual(parsedHeaders, expected) { + t.Errorf("Expected %v, got %v", expected, parsedHeaders) } } diff --git a/util/string.go b/util/string.go index 75c0e8941..5636bc6c4 100644 --- a/util/string.go +++ b/util/string.go @@ -23,3 +23,12 @@ func toHostname(url string) string { return strings.Split(stripped, "/")[0] } + +// SplitLines splits a string into lines by '\r\n' or '\n'. +func SplitLines(s string) []string { + if strings.Contains(s, "\r\n") { + return strings.Split(s, "\r\n") + } + + return strings.Split(s, "\n") +} diff --git a/web/save.go b/web/save.go index 14c8a99e0..83353c6be 100755 --- a/web/save.go +++ b/web/save.go @@ -116,7 +116,7 @@ func checkAndSave(request *http.Request) string { dnsConf.Ipv4.URL = strings.TrimSpace(v.Ipv4Url) dnsConf.Ipv4.NetInterface = v.Ipv4NetInterface dnsConf.Ipv4.Cmd = strings.TrimSpace(v.Ipv4Cmd) - dnsConf.Ipv4.Domains = splitLines(v.Ipv4Domains) + dnsConf.Ipv4.Domains = util.SplitLines(v.Ipv4Domains) dnsConf.Ipv6.Enable = v.Ipv6Enable dnsConf.Ipv6.GetType = v.Ipv6GetType @@ -124,7 +124,7 @@ func checkAndSave(request *http.Request) string { dnsConf.Ipv6.NetInterface = v.Ipv6NetInterface dnsConf.Ipv6.Cmd = strings.TrimSpace(v.Ipv6Cmd) dnsConf.Ipv6.Ipv6Reg = strings.TrimSpace(v.Ipv6Reg) - dnsConf.Ipv6.Domains = splitLines(v.Ipv6Domains) + dnsConf.Ipv6.Domains = util.SplitLines(v.Ipv6Domains) if k < len(conf.DnsConf) { c := &conf.DnsConf[k] @@ -160,12 +160,3 @@ func checkAndSave(request *http.Request) string { } return "ok" } - -// splitLines splits a string into lines by '\r\n' or '\n'. -func splitLines(s string) []string { - if strings.Contains(s, "\r\n") { - return strings.Split(s, "\r\n") - } - - return strings.Split(s, "\n") -}