diff --git a/cmd/network.go b/cmd/network.go index 286abb5..6a2d547 100644 --- a/cmd/network.go +++ b/cmd/network.go @@ -21,6 +21,7 @@ import ( "bytes" "crypto/hmac" "crypto/sha1" + "crypto/sha512" "encoding/base64" "encoding/json" "errors" @@ -218,6 +219,195 @@ func encodeRequestParams(params url.Values) string { return buf.String() } +func cloneRequestParams(params url.Values) url.Values { + cloned := make(url.Values) + for key, values := range params { + for _, value := range values { + cloned.Add(key, value) + } + } + return cloned +} + +func buildAPIRequestParams(r *Request, api string, args []string) url.Values { + params := make(url.Values) + params.Add("command", api) + apiData := r.Config.GetCache()[api] + for _, arg := range args { + if apiData != nil { + skip := false + for _, fakeArg := range apiData.FakeArgs { + if strings.HasPrefix(arg, fakeArg) { + skip = true + break + } + } + if skip { + continue + } + + } + parts := strings.SplitN(arg, "=", 2) + if len(parts) == 2 { + key := parts[0] + value := parts[1] + if strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"") { + value = value[1 : len(value)-1] + } + if strings.HasPrefix(value, "@") { + possibleFileName := value[1:] + if fileInfo, err := os.Stat(possibleFileName); err == nil && !fileInfo.IsDir() { + bytes, err := ioutil.ReadFile(possibleFileName) + config.Debug() + if err == nil { + value = string(bytes) + config.Debug("Content for argument ", key, " read from file: ", possibleFileName, " is: ", value) + } + } + } + params.Add(key, value) + } + } + signatureversion := "3" + expiresKey := "expires" + params.Add("response", "json") + params.Add("signatureversion", signatureversion) + params.Add(expiresKey, time.Now().UTC().Add(15*time.Minute).Format(time.RFC3339)) + return params +} + +func signRequest(unsignedRequest, secretKey, algorithm string) (string, error) { + signatureAlgorithm, err := config.NormalizeSignatureAlgorithm(algorithm) + if err != nil { + return "", err + } + + var signature []byte + switch signatureAlgorithm { + case config.SignatureAlgorithmHmacSHA1: + mac := hmac.New(sha1.New, []byte(secretKey)) + mac.Write([]byte(strings.ToLower(unsignedRequest))) + signature = mac.Sum(nil) + case config.SignatureAlgorithmHmacSHA512: + mac := hmac.New(sha512.New, []byte(secretKey)) + mac.Write([]byte(strings.ToLower(unsignedRequest))) + signature = mac.Sum(nil) + default: + return "", errors.New("signature algorithm must be concrete") + } + return base64.StdEncoding.EncodeToString(signature), nil +} + +func executeSignedAPIRequest(r *Request, unsignedParams url.Values, algorithm string) (*http.Response, error) { + params := cloneRequestParams(unsignedParams) + encodedParams := encodeRequestParams(params) + + signature, err := signRequest(encodedParams, r.Config.ActiveProfile.SecretKey, algorithm) + if err != nil { + return nil, err + } + if r.Config.Core.PostRequest { + params.Add("signature", signature) + } else { + encodedParams = encodedParams + fmt.Sprintf("&signature=%s", url.QueryEscape(signature)) + params = nil + } + + requestURL := fmt.Sprintf("%s?%s", r.Config.ActiveProfile.URL, encodedParams) + config.Debug("NewAPIRequest API request URL:", requestURL) + return executeRequest(r, requestURL, params) +} + +func parseAPIResponse(body []byte) (map[string]interface{}, error) { + var data map[string]interface{} + if err := json.Unmarshal(body, &data); err != nil { + return nil, errors.New("failed to decode response") + } + + if apiResponse := getResponseData(data); apiResponse != nil { + if _, ok := apiResponse["errorcode"]; ok { + return nil, fmt.Errorf("(HTTP %v, error code %v) %v", apiResponse["errorcode"], apiResponse["cserrorcode"], apiResponse["errortext"]) + } + return apiResponse, nil + } + + return nil, errors.New("failed to decode response") +} + +func isAuthenticationFailure(statusCode int, err error) bool { + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + return true + } + if err == nil { + return false + } + errText := strings.ToLower(err.Error()) + for _, marker := range []string{"signature", "authenticate", "authentication", "credential", "unauthoriz", "api key", "apikey"} { + if strings.Contains(errText, marker) { + return true + } + } + return false +} + +func persistDetectedSignatureAlgorithm(r *Request, algorithm string) { + r.Config.ActiveProfile.SignatureAlgorithm = algorithm + if r.CredentialsSupplied { + config.Debug("Credentials supplied on command-line, not persisting detected signature algorithm") + return + } + r.Config.UpdateConfig("signaturealgorithm", algorithm, true) +} + +func detectSignatureAlgorithm(r *Request) (string, error) { + attempts := []string{config.SignatureAlgorithmHmacSHA512, config.SignatureAlgorithmHmacSHA1} + var lastErr error + + for _, algorithm := range attempts { + config.Debug("Trying API signature algorithm probe:", algorithm) + params := buildAPIRequestParams(r, "listApis", []string{"listall=true"}) + params.Add("apiKey", r.Config.ActiveProfile.APIKey) + response, err := executeSignedAPIRequest(r, params, algorithm) + if err != nil { + config.Debug("API signature algorithm probe failed before response for ", algorithm, ": ", err) + lastErr = err + continue + } + + body, _ := ioutil.ReadAll(response.Body) + config.Debug("Signature algorithm probe response body:", string(body)) + if _, err := parseAPIResponse(body); err == nil { + config.Debug("Selected API signature algorithm:", algorithm) + persistDetectedSignatureAlgorithm(r, algorithm) + return algorithm, nil + } else { + lastErr = err + if isAuthenticationFailure(response.StatusCode, err) { + config.Debug("API signature algorithm probe failed authentication for ", algorithm, ": ", err) + } else { + config.Debug("API signature algorithm probe failed with non-authentication error for ", algorithm, ": ", err) + } + } + } + + config.Debug("Signature algorithm autodetection failed; attempted algorithms:", strings.Join(attempts, ", ")) + if lastErr != nil { + return "", lastErr + } + return "", errors.New("failed to detect signature algorithm") +} + +func activeSignatureAlgorithm(r *Request) (string, error) { + signatureAlgorithm, err := config.NormalizeSignatureAlgorithm(r.Config.ActiveProfile.SignatureAlgorithm) + if err != nil { + return "", err + } + if signatureAlgorithm == config.SignatureAlgorithmAuto { + return detectSignatureAlgorithm(r) + } + return signatureAlgorithm, nil +} + func getResponseData(data map[string]interface{}) map[string]interface{} { for k := range data { if strings.HasSuffix(k, "response") { @@ -276,72 +466,28 @@ func pollAsyncJob(r *Request, jobID string) (map[string]interface{}, error) { // NewAPIRequest makes an API request to configured management server func NewAPIRequest(r *Request, api string, args []string, isAsync bool) (map[string]interface{}, error) { - params := make(url.Values) - params.Add("command", api) - apiData := r.Config.GetCache()[api] - for _, arg := range args { - if apiData != nil { - skip := false - for _, fakeArg := range apiData.FakeArgs { - if strings.HasPrefix(arg, fakeArg) { - skip = true - break - } - } - if skip { - continue - } - - } - parts := strings.SplitN(arg, "=", 2) - if len(parts) == 2 { - key := parts[0] - value := parts[1] - if strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"") { - value = value[1 : len(value)-1] - } - if strings.HasPrefix(value, "@") { - possibleFileName := value[1:] - if fileInfo, err := os.Stat(possibleFileName); err == nil && !fileInfo.IsDir() { - bytes, err := ioutil.ReadFile(possibleFileName) - config.Debug() - if err == nil { - value = string(bytes) - config.Debug("Content for argument ", key, " read from file: ", possibleFileName, " is: ", value) - } - } - } - params.Add(key, value) - } - } - signatureversion := "3" - expiresKey := "expires" - params.Add("response", "json") - params.Add("signatureversion", signatureversion) - params.Add(expiresKey, time.Now().UTC().Add(15*time.Minute).Format(time.RFC3339)) + params := buildAPIRequestParams(r, api, args) var encodedParams string var err error + usingSessionAuth := false if len(r.Config.ActiveProfile.APIKey) > 0 && len(r.Config.ActiveProfile.SecretKey) > 0 { apiKey := r.Config.ActiveProfile.APIKey - secretKey := r.Config.ActiveProfile.SecretKey - if len(apiKey) > 0 { params.Add("apiKey", apiKey) } - encodedParams = encodeRequestParams(params) - - mac := hmac.New(sha1.New, []byte(secretKey)) - mac.Write([]byte(strings.ToLower(encodedParams))) - signature := base64.StdEncoding.EncodeToString(mac.Sum(nil)) - if r.Config.Core.PostRequest { - params.Add("signature", signature) - } else { - encodedParams = encodedParams + fmt.Sprintf("&signature=%s", url.QueryEscape(signature)) - params = nil + signatureAlgorithm, err := activeSignatureAlgorithm(r) + if err != nil { + return nil, err } + response, err := executeSignedAPIRequest(r, params, signatureAlgorithm) + if err != nil { + return nil, err + } + return processAPIResponse(r, response, isAsync) } else if len(r.Config.ActiveProfile.Username) > 0 && len(r.Config.ActiveProfile.Password) > 0 { + usingSessionAuth = true sessionKey, err := Login(r) if err != nil { return nil, err @@ -367,7 +513,7 @@ func NewAPIRequest(r *Request, api string, args []string, isAsync bool) (map[str config.Debug("Credentials supplied on command-line, not falling back to login") } - if response.StatusCode == http.StatusUnauthorized && !r.CredentialsSupplied { + if usingSessionAuth && response.StatusCode == http.StatusUnauthorized && !r.CredentialsSupplied { r.Client().Jar, _ = cookiejar.New(nil) sessionKey, err := Login(r) if err != nil { @@ -384,27 +530,26 @@ func NewAPIRequest(r *Request, api string, args []string, isAsync bool) (map[str } } + return processAPIResponse(r, response, isAsync) +} + +func processAPIResponse(r *Request, response *http.Response, isAsync bool) (map[string]interface{}, error) { body, _ := ioutil.ReadAll(response.Body) config.Debug("NewAPIRequest response body:", string(body)) - var data map[string]interface{} - _ = json.Unmarshal([]byte(body), &data) + apiResponse, err := parseAPIResponse(body) + if err != nil { + return nil, err + } if isAsync && r.Config.Core.AsyncBlock { - if jobResponse := getResponseData(data); jobResponse != nil && jobResponse["jobid"] != nil { - jobID := jobResponse["jobid"].(string) + if apiResponse["jobid"] != nil { + jobID := apiResponse["jobid"].(string) return pollAsyncJob(r, jobID) } } - if apiResponse := getResponseData(data); apiResponse != nil { - if _, ok := apiResponse["errorcode"]; ok { - return nil, fmt.Errorf("(HTTP %v, error code %v) %v", apiResponse["errorcode"], apiResponse["cserrorcode"], apiResponse["errortext"]) - } - return apiResponse, nil - } - - return nil, errors.New("failed to decode response") + return apiResponse, nil } // we can implement further conditions to do POST or GET (or other http commands) here diff --git a/cmd/network_test.go b/cmd/network_test.go new file mode 100644 index 0000000..6333dab --- /dev/null +++ b/cmd/network_test.go @@ -0,0 +1,255 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package cmd + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strings" + "sync" + "testing" + + "github.com/apache/cloudstack-cloudmonkey/config" +) + +const ( + testAPIKey = "api-key" + testSecretKey = "secret-key" +) + +type signatureAttempt struct { + Command string + Algorithm string +} + +type signatureTestServer struct { + server *httptest.Server + attempts []signatureAttempt + mu sync.Mutex +} + +func newSignatureTestServer(t *testing.T, allowedAlgorithms map[string]bool) *signatureTestServer { + t.Helper() + testServer := &signatureTestServer{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if err := req.ParseForm(); err != nil { + http.Error(w, "failed to parse request form", http.StatusBadRequest) + return + } + + command := req.Form.Get("command") + algorithm := identifySignatureAlgorithm(req.Form) + testServer.mu.Lock() + testServer.attempts = append(testServer.attempts, signatureAttempt{ + Command: command, + Algorithm: algorithm, + }) + testServer.mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + if !allowedAlgorithms[algorithm] { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"errorresponse":{"errorcode":401,"cserrorcode":9999,"errortext":"unable to verify request signature"}}`) + return + } + + fmt.Fprint(w, successResponse(command)) + })) + testServer.server = server + t.Cleanup(server.Close) + return testServer +} + +func identifySignatureAlgorithm(form url.Values) string { + unsigned := cloneRequestParams(form) + signature := unsigned.Get("signature") + unsigned.Del("signature") + + unsignedRequest := encodeRequestParams(unsigned) + for _, algorithm := range []string{config.SignatureAlgorithmHmacSHA512, config.SignatureAlgorithmHmacSHA1} { + expected, err := signRequest(unsignedRequest, testSecretKey, algorithm) + if err != nil { + return "unknown" + } + if signature == expected { + return algorithm + } + } + return "unknown" +} + +func successResponse(command string) string { + responseKey := strings.ToLower(command) + "response" + if strings.EqualFold(command, "listApis") { + return `{"listapisresponse":{"count":0,"api":[]}}` + } + return fmt.Sprintf(`{"%s":{"success":true}}`, responseKey) +} + +func newTestRequest(t *testing.T, serverURL string, signatureAlgorithm string) *Request { + t.Helper() + dir := t.TempDir() + cfg := &config.Config{ + Dir: dir, + ConfigFile: filepath.Join(dir, "config"), + HistoryFile: filepath.Join(dir, "history"), + Core: &config.Core{ + Prompt: "cmk", + AsyncBlock: true, + Timeout: 30, + Output: config.JSON, + VerifyCert: true, + ProfileName: "localcloud", + AutoComplete: true, + PostRequest: true, + }, + ActiveProfile: &config.ServerProfile{ + URL: serverURL, + Domain: "/", + APIKey: testAPIKey, + SecretKey: testSecretKey, + SignatureAlgorithm: signatureAlgorithm, + Client: http.DefaultClient, + }, + } + return NewRequest(GetAPIHandler(), cfg, nil, false) +} + +func (s *signatureTestServer) snapshotAttempts() []signatureAttempt { + s.mu.Lock() + defer s.mu.Unlock() + attempts := make([]signatureAttempt, len(s.attempts)) + copy(attempts, s.attempts) + return attempts +} + +func TestSignRequestHmacSHA1(t *testing.T) { + got, err := signRequest("apikey=abc&command=listZones&response=json", "secret", config.SignatureAlgorithmHmacSHA1) + if err != nil { + t.Fatal(err) + } + if got != "tcMI1Kpm20pLhrrVYtCCcualuBU=" { + t.Fatalf("HmacSHA1 signature = %q", got) + } +} + +func TestSignRequestHmacSHA512(t *testing.T) { + got, err := signRequest("apikey=abc&command=listZones&response=json", "secret", config.SignatureAlgorithmHmacSHA512) + if err != nil { + t.Fatal(err) + } + if got != "/JUUf9SPvNd0sjdbfIxsCxYItNUlGauI+T71cOhHd5fYffHuVAXIba9RzFBK+empezGCzlhw4+R9LFri3CG+oQ==" { + t.Fatalf("HmacSHA512 signature = %q", got) + } +} + +func TestExplicitHmacSHA512DoesNotTryHmacSHA1(t *testing.T) { + server := newSignatureTestServer(t, map[string]bool{ + config.SignatureAlgorithmHmacSHA1: true, + }) + request := newTestRequest(t, server.server.URL, config.SignatureAlgorithmHmacSHA512) + + if _, err := NewAPIRequest(request, "listZones", nil, false); err == nil { + t.Fatal("expected explicit HmacSHA512 request to fail") + } + + attempts := server.snapshotAttempts() + if len(attempts) != 1 { + t.Fatalf("attempt count = %d, want 1", len(attempts)) + } + if attempts[0].Algorithm != config.SignatureAlgorithmHmacSHA512 { + t.Fatalf("attempted algorithm = %q, want %q", attempts[0].Algorithm, config.SignatureAlgorithmHmacSHA512) + } +} + +func TestAutoPersistsHmacSHA512WhenProbeSucceeds(t *testing.T) { + server := newSignatureTestServer(t, map[string]bool{ + config.SignatureAlgorithmHmacSHA512: true, + }) + request := newTestRequest(t, server.server.URL, config.SignatureAlgorithmAuto) + + if _, err := NewAPIRequest(request, "listZones", nil, false); err != nil { + t.Fatal(err) + } + + if got := request.Config.ActiveProfile.SignatureAlgorithm; got != config.SignatureAlgorithmHmacSHA512 { + t.Fatalf("persisted signature algorithm = %q, want %q", got, config.SignatureAlgorithmHmacSHA512) + } + attempts := server.snapshotAttempts() + if len(attempts) != 2 { + t.Fatalf("attempt count = %d, want 2", len(attempts)) + } + if attempts[0] != (signatureAttempt{Command: "listApis", Algorithm: config.SignatureAlgorithmHmacSHA512}) { + t.Fatalf("first attempt = %+v, want listApis with HmacSHA512", attempts[0]) + } + if attempts[1] != (signatureAttempt{Command: "listZones", Algorithm: config.SignatureAlgorithmHmacSHA512}) { + t.Fatalf("second attempt = %+v, want listZones with HmacSHA512", attempts[1]) + } +} + +func TestAutoFallsBackAndPersistsHmacSHA1(t *testing.T) { + server := newSignatureTestServer(t, map[string]bool{ + config.SignatureAlgorithmHmacSHA1: true, + }) + request := newTestRequest(t, server.server.URL, config.SignatureAlgorithmAuto) + + if _, err := NewAPIRequest(request, "listZones", nil, false); err != nil { + t.Fatal(err) + } + + if got := request.Config.ActiveProfile.SignatureAlgorithm; got != config.SignatureAlgorithmHmacSHA1 { + t.Fatalf("persisted signature algorithm = %q, want %q", got, config.SignatureAlgorithmHmacSHA1) + } + attempts := server.snapshotAttempts() + if len(attempts) != 3 { + t.Fatalf("attempt count = %d, want 3", len(attempts)) + } + if attempts[0] != (signatureAttempt{Command: "listApis", Algorithm: config.SignatureAlgorithmHmacSHA512}) { + t.Fatalf("first attempt = %+v, want listApis with HmacSHA512", attempts[0]) + } + if attempts[1] != (signatureAttempt{Command: "listApis", Algorithm: config.SignatureAlgorithmHmacSHA1}) { + t.Fatalf("second attempt = %+v, want listApis with HmacSHA1", attempts[1]) + } + if attempts[2] != (signatureAttempt{Command: "listZones", Algorithm: config.SignatureAlgorithmHmacSHA1}) { + t.Fatalf("third attempt = %+v, want listZones with HmacSHA1", attempts[2]) + } +} + +func TestAutoDoesNotRetryUserCommandDirectly(t *testing.T) { + server := newSignatureTestServer(t, map[string]bool{ + config.SignatureAlgorithmHmacSHA1: true, + }) + request := newTestRequest(t, server.server.URL, config.SignatureAlgorithmAuto) + + if _, err := NewAPIRequest(request, "deployVirtualMachine", []string{"serviceofferingid=1"}, false); err != nil { + t.Fatal(err) + } + + userCommandAttempts := 0 + for _, attempt := range server.snapshotAttempts() { + if attempt.Command == "deployVirtualMachine" { + userCommandAttempts++ + } + } + if userCommandAttempts != 1 { + t.Fatalf("deployVirtualMachine attempt count = %d, want 1", userCommandAttempts) + } +} diff --git a/cmd/set.go b/cmd/set.go index c8bba8a..f55f064 100644 --- a/cmd/set.go +++ b/cmd/set.go @@ -31,21 +31,22 @@ func init() { Name: "set", Help: "Configures options for cmk", SubCommands: map[string][]string{ - "prompt": {"🐵", "🐱", "random"}, - "asyncblock": {"true", "false"}, - "timeout": {"600", "1800", "3600"}, - "output": config.GetOutputFormats(), - "profile": {}, - "url": {}, - "username": {}, - "password": {}, - "domain": {}, - "apikey": {}, - "secretkey": {}, - "verifycert": {"true", "false"}, - "debug": {"true", "false"}, - "autocomplete": {"true", "false"}, - "postrequest": {"true", "false"}, + "prompt": {"🐵", "🐱", "random"}, + "asyncblock": {"true", "false"}, + "timeout": {"600", "1800", "3600"}, + "output": config.GetOutputFormats(), + "profile": {}, + "url": {}, + "username": {}, + "password": {}, + "domain": {}, + "apikey": {}, + "secretkey": {}, + "signaturealgorithm": config.GetSignatureAlgorithms(), + "verifycert": {"true", "false"}, + "debug": {"true", "false"}, + "autocomplete": {"true", "false"}, + "postrequest": {"true", "false"}, }, Handle: func(r *Request) error { if len(r.Args) < 1 { @@ -63,7 +64,7 @@ func init() { subCommand = "output" } validArgs := r.Command.SubCommands[subCommand] - if len(validArgs) != 0 && subCommand != "timeout" { + if len(validArgs) != 0 && subCommand != "timeout" && subCommand != "signaturealgorithm" { if !config.CheckIfValuePresent(validArgs, value) { return errors.New("Invalid value set for " + subCommand + ". Supported values: " + strings.Join(validArgs, ", ")) } diff --git a/config/config.go b/config/config.go index cf32b69..51e4866 100644 --- a/config/config.go +++ b/config/config.go @@ -20,6 +20,7 @@ package config import ( "context" "crypto/tls" + "errors" "fmt" "net/http" "net/http/cookiejar" @@ -28,6 +29,7 @@ import ( "path" "path/filepath" "strconv" + "strings" "time" "github.com/briandowns/spinner" @@ -56,15 +58,23 @@ var nonEmptyConfigKeys = map[string]bool{ // DefaultACSAPIEndpoint is the default API endpoint for CloudStack. const DefaultACSAPIEndpoint = "http://localhost:8080/client/api" +// Supported API request signature algorithms. +const ( + SignatureAlgorithmAuto = "auto" + SignatureAlgorithmHmacSHA1 = "HmacSHA1" + SignatureAlgorithmHmacSHA512 = "HmacSHA512" +) + // ServerProfile describes a management server type ServerProfile struct { - URL string `ini:"url"` - Username string `ini:"username"` - Password string `ini:"password"` - Domain string `ini:"domain"` - APIKey string `ini:"apikey"` - SecretKey string `ini:"secretkey"` - Client *http.Client `ini:"-"` + URL string `ini:"url"` + Username string `ini:"username"` + Password string `ini:"password"` + Domain string `ini:"domain"` + APIKey string `ini:"apikey"` + SecretKey string `ini:"secretkey"` + SignatureAlgorithm string `ini:"signaturealgorithm"` + Client *http.Client `ini:"-"` } // Core block describes common options for the CLI @@ -99,6 +109,36 @@ func GetOutputFormats() []string { return []string{"column", "csv", "json", "table", "text", "default"} } +// GetSignatureAlgorithms returns the supported API request signature algorithms. +func GetSignatureAlgorithms() []string { + return []string{SignatureAlgorithmAuto, SignatureAlgorithmHmacSHA1, SignatureAlgorithmHmacSHA512} +} + +// NormalizeSignatureAlgorithm returns a canonical signature algorithm name. +func NormalizeSignatureAlgorithm(value string) (string, error) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "", SignatureAlgorithmAuto: + return SignatureAlgorithmAuto, nil + case strings.ToLower(SignatureAlgorithmHmacSHA1): + return SignatureAlgorithmHmacSHA1, nil + case strings.ToLower(SignatureAlgorithmHmacSHA512): + return SignatureAlgorithmHmacSHA512, nil + default: + return "", errors.New("unsupported signature algorithm") + } +} + +func normalizeProfile(profile *ServerProfile) { + if profile == nil { + return + } + signatureAlgorithm, err := NormalizeSignatureAlgorithm(profile.SignatureAlgorithm) + if err != nil { + signatureAlgorithm = SignatureAlgorithmAuto + } + profile.SignatureAlgorithm = signatureAlgorithm +} + // CheckIfValuePresent checks if an element is present in the dataset. func CheckIfValuePresent(dataset []string, element string) bool { for _, arg := range dataset { @@ -170,12 +210,13 @@ func defaultCoreConfig() Core { func defaultProfile() ServerProfile { return ServerProfile{ - URL: DefaultACSAPIEndpoint, - Username: "admin", - Password: "password", - Domain: "/", - APIKey: "", - SecretKey: "", + URL: DefaultACSAPIEndpoint, + Username: "admin", + Password: "password", + Domain: "/", + APIKey: "", + SecretKey: "", + SignatureAlgorithm: SignatureAlgorithmAuto, } } @@ -231,6 +272,7 @@ func newHTTPClient(cfg *Config) *http.Client { } func setActiveProfile(cfg *Config, profile *ServerProfile) { + normalizeProfile(profile) cfg.ActiveProfile = profile cfg.ActiveProfile.Client = newHTTPClient(cfg) } @@ -320,6 +362,19 @@ func saveConfig(cfg *Config) *Config { conf.Section(cfg.Core.ProfileName).MapTo(profile) setActiveProfile(cfg, profile) } + for _, section := range conf.Sections() { + if section.Name() == ini.DEFAULT_SECTION { + continue + } + signatureAlgorithm, err := NormalizeSignatureAlgorithm(section.Key("signaturealgorithm").String()) + if err != nil { + signatureAlgorithm = SignatureAlgorithmAuto + } + section.Key("signaturealgorithm").SetValue(signatureAlgorithm) + if section.Name() == cfg.Core.ProfileName && cfg.ActiveProfile != nil { + cfg.ActiveProfile.SignatureAlgorithm = signatureAlgorithm + } + } // Save conf.SaveTo(cfg.ConfigFile) @@ -391,6 +446,13 @@ func (c *Config) UpdateConfig(key string, value string, update bool) { c.ActiveProfile.APIKey = value case "secretkey": c.ActiveProfile.SecretKey = value + case "signaturealgorithm": + signatureAlgorithm, err := NormalizeSignatureAlgorithm(value) + if err != nil { + fmt.Println("Invalid value set for signaturealgorithm. Supported values: " + strings.Join(GetSignatureAlgorithms(), ", ")) + return + } + c.ActiveProfile.SignatureAlgorithm = signatureAlgorithm case "verifycert": c.Core.VerifyCert = value == "true" case "debug": diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..cebd6fe --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package config + +import ( + "path/filepath" + "strings" + "testing" + + ini "gopkg.in/ini.v1" +) + +func testCore() *Core { + return &Core{ + Prompt: "cmk", + AsyncBlock: true, + Timeout: 1800, + Output: JSON, + VerifyCert: true, + ProfileName: "localcloud", + AutoComplete: true, + PostRequest: true, + } +} + +func TestDefaultProfileSignatureAlgorithmAuto(t *testing.T) { + profile := defaultProfile() + if profile.SignatureAlgorithm != SignatureAlgorithmAuto { + t.Fatalf("default signature algorithm = %q, want %q", profile.SignatureAlgorithm, SignatureAlgorithmAuto) + } +} + +func TestNormalizeSignatureAlgorithmCanonicalizesInput(t *testing.T) { + tests := map[string]string{ + "": SignatureAlgorithmAuto, + " AUTO ": SignatureAlgorithmAuto, + "hmacsha1": SignatureAlgorithmHmacSHA1, + "HMACSHA512": SignatureAlgorithmHmacSHA512, + "HmacSHA512": SignatureAlgorithmHmacSHA512, + " HmacSHA1 ": SignatureAlgorithmHmacSHA1, + } + + for input, want := range tests { + got, err := NormalizeSignatureAlgorithm(input) + if err != nil { + t.Fatalf("NormalizeSignatureAlgorithm(%q) returned error: %v", input, err) + } + if got != want { + t.Fatalf("NormalizeSignatureAlgorithm(%q) = %q, want %q", input, got, want) + } + } +} + +func TestSaveConfigMigratesMissingSignatureAlgorithmToAuto(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "config") + conf := ini.Empty() + conf.Section(ini.DEFAULT_SECTION).ReflectFrom(testCore()) + conf.Section("localcloud").ReflectFrom(&ServerProfile{ + URL: DefaultACSAPIEndpoint, + Username: "admin", + Password: "password", + Domain: "/", + APIKey: "api-key", + SecretKey: "secret-key", + }) + conf.Section("localcloud").DeleteKey("signaturealgorithm") + if err := conf.SaveTo(configFile); err != nil { + t.Fatal(err) + } + + cfg := &Config{ + Dir: dir, + ConfigFile: configFile, + HistoryFile: filepath.Join(dir, "history"), + Core: testCore(), + } + saveConfig(cfg) + + updated := readConfig(cfg) + got := updated.Section("localcloud").Key("signaturealgorithm").String() + if got != SignatureAlgorithmAuto { + t.Fatalf("persisted signaturealgorithm = %q, want %q", got, SignatureAlgorithmAuto) + } + if cfg.ActiveProfile.SignatureAlgorithm != SignatureAlgorithmAuto { + t.Fatalf("active profile signature algorithm = %q, want %q", cfg.ActiveProfile.SignatureAlgorithm, SignatureAlgorithmAuto) + } +} + +func TestPromptAddsFIPSIndicatorForHmacSHA512(t *testing.T) { + cfg := &Config{ + Core: testCore(), + ActiveProfile: &ServerProfile{ + SignatureAlgorithm: SignatureAlgorithmHmacSHA512, + }, + } + + if got := cfg.GetPrompt(); !strings.Contains(got, "(localcloud-fips)") { + t.Fatalf("prompt = %q, want FIPS profile indicator", got) + } +} diff --git a/config/prompt.go b/config/prompt.go index 1e3ae40..4625191 100644 --- a/config/prompt.go +++ b/config/prompt.go @@ -48,5 +48,9 @@ func renderPrompt(prompt string) string { // GetPrompt returns prompt that the CLI should use func (c *Config) GetPrompt() string { - return fmt.Sprintf("(%s) %s > ", c.Core.ProfileName, renderPrompt(c.Core.Prompt)) + profileName := c.Core.ProfileName + if c.ActiveProfile != nil && c.ActiveProfile.SignatureAlgorithm == SignatureAlgorithmHmacSHA512 { + profileName += "-fips" + } + return fmt.Sprintf("(%s) %s > ", profileName, renderPrompt(c.Core.Prompt)) }