diff --git a/api/client.go b/api/client.go index 53b87ae20..f7ca60b67 100644 --- a/api/client.go +++ b/api/client.go @@ -25,8 +25,6 @@ import ( "time" ) -type Warnings []string - // DefaultRoundTripper is used if no RoundTripper is set in Config. var DefaultRoundTripper http.RoundTripper = &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -57,32 +55,7 @@ func (cfg *Config) roundTripper() http.RoundTripper { // Client is the interface for an API client. type Client interface { URL(ep string, args map[string]string) *url.URL - Do(context.Context, *http.Request) (*http.Response, []byte, Warnings, error) -} - -// DoGetFallback will attempt to do the request as-is, and on a 405 it will fallback to a GET request. -func DoGetFallback(c Client, ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) { - req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode())) - if err != nil { - return nil, nil, nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, body, warnings, err := c.Do(ctx, req) - if resp != nil && resp.StatusCode == http.StatusMethodNotAllowed { - u.RawQuery = args.Encode() - req, err = http.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - return nil, nil, warnings, err - } - - } else { - if err != nil { - return resp, body, warnings, err - } - return resp, body, warnings, nil - } - return c.Do(ctx, req) + Do(context.Context, *http.Request) (*http.Response, []byte, error) } // NewClient returns a new Client. @@ -120,7 +93,7 @@ func (c *httpClient) URL(ep string, args map[string]string) *url.URL { return &u } -func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, Warnings, error) { +func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { if ctx != nil { req = req.WithContext(ctx) } @@ -132,7 +105,7 @@ func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, }() if err != nil { - return nil, nil, nil, err + return nil, nil, err } var body []byte @@ -152,5 +125,5 @@ func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, case <-done: } - return resp, body, nil, err + return resp, body, err } diff --git a/api/client_test.go b/api/client_test.go index b3c95eee6..47094fccd 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -14,10 +14,7 @@ package api import ( - "context" - "encoding/json" "net/http" - "net/http/httptest" "net/url" "testing" ) @@ -114,72 +111,3 @@ func TestClientURL(t *testing.T) { } } } - -func TestDoGetFallback(t *testing.T) { - v := url.Values{"a": []string{"1", "2"}} - - type testResponse struct { - Values string - Method string - } - - // Start a local HTTP server. - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.ParseForm() - r := &testResponse{ - Values: req.Form.Encode(), - Method: req.Method, - } - - body, _ := json.Marshal(r) - - if req.Method == http.MethodPost { - if req.URL.Path == "/blockPost" { - http.Error(w, string(body), http.StatusMethodNotAllowed) - return - } - } - - w.Write(body) - })) - // Close the server when test finishes. - defer server.Close() - - u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } - client := &httpClient{client: *(server.Client())} - - // Do a post, and ensure that the post succeeds. - _, b, _, err := DoGetFallback(client, context.TODO(), u, v) - if err != nil { - t.Fatalf("Error doing local request: %v", err) - } - resp := &testResponse{} - if err := json.Unmarshal(b, resp); err != nil { - t.Fatal(err) - } - if resp.Method != http.MethodPost { - t.Fatalf("Mismatch method") - } - if resp.Values != v.Encode() { - t.Fatalf("Mismatch in values") - } - - // Do a fallbcak to a get. - u.Path = "/blockPost" - _, b, _, err = DoGetFallback(client, context.TODO(), u, v) - if err != nil { - t.Fatalf("Error doing local request: %v", err) - } - if err := json.Unmarshal(b, resp); err != nil { - t.Fatal(err) - } - if resp.Method != http.MethodGet { - t.Fatalf("Mismatch method") - } - if resp.Values != v.Encode() { - t.Fatalf("Mismatch in values") - } -} diff --git a/api/prometheus/v1/api.go b/api/prometheus/v1/api.go index 00cd76ea2..06319a89f 100644 --- a/api/prometheus/v1/api.go +++ b/api/prometheus/v1/api.go @@ -21,7 +21,9 @@ import ( "fmt" "math" "net/http" + "net/url" "strconv" + "strings" "time" "unsafe" @@ -228,15 +230,15 @@ type API interface { // Flags returns the flag values that Prometheus was launched with. Flags(ctx context.Context) (FlagsResult, error) // LabelNames returns all the unique label names present in the block in sorted order. - LabelNames(ctx context.Context) ([]string, api.Warnings, error) + LabelNames(ctx context.Context) ([]string, Warnings, error) // LabelValues performs a query for the values of the given label. - LabelValues(ctx context.Context, label string) (model.LabelValues, api.Warnings, error) + LabelValues(ctx context.Context, label string) (model.LabelValues, Warnings, error) // Query performs a query for the given time. - Query(ctx context.Context, query string, ts time.Time) (model.Value, api.Warnings, error) + Query(ctx context.Context, query string, ts time.Time) (model.Value, Warnings, error) // QueryRange performs a query for the given range. - QueryRange(ctx context.Context, query string, r Range) (model.Value, api.Warnings, error) + QueryRange(ctx context.Context, query string, r Range) (model.Value, Warnings, error) // Series finds series by label matchers. - Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, api.Warnings, error) + Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, Warnings, error) // Snapshot creates a snapshot of all current data into snapshots/- // under the TSDB's data directory and returns the directory as response. Snapshot(ctx context.Context, skipHead bool) (SnapshotResult, error) @@ -515,11 +517,15 @@ func (qr *queryResult) UnmarshalJSON(b []byte) error { // // It is safe to use the returned API from multiple goroutines. func NewAPI(c api.Client) API { - return &httpAPI{client: apiClient{c}} + return &httpAPI{ + client: &apiClientImpl{ + client: c, + }, + } } type httpAPI struct { - client api.Client + client apiClient } func (h *httpAPI) Alerts(ctx context.Context) (AlertsResult, error) { @@ -624,7 +630,7 @@ func (h *httpAPI) Flags(ctx context.Context) (FlagsResult, error) { return res, json.Unmarshal(body, &res) } -func (h *httpAPI) LabelNames(ctx context.Context) ([]string, api.Warnings, error) { +func (h *httpAPI) LabelNames(ctx context.Context) ([]string, Warnings, error) { u := h.client.URL(epLabels, nil) req, err := http.NewRequest(http.MethodGet, u.String(), nil) if err != nil { @@ -638,7 +644,7 @@ func (h *httpAPI) LabelNames(ctx context.Context) ([]string, api.Warnings, error return labelNames, w, json.Unmarshal(body, &labelNames) } -func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelValues, api.Warnings, error) { +func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelValues, Warnings, error) { u := h.client.URL(epLabelValues, map[string]string{"name": label}) req, err := http.NewRequest(http.MethodGet, u.String(), nil) if err != nil { @@ -652,7 +658,7 @@ func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelVal return labelValues, w, json.Unmarshal(body, &labelValues) } -func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model.Value, api.Warnings, error) { +func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model.Value, Warnings, error) { u := h.client.URL(epQuery, nil) q := u.Query() @@ -661,7 +667,7 @@ func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model. q.Set("time", formatTime(ts)) } - _, body, warnings, err := api.DoGetFallback(h.client, ctx, u, q) + _, body, warnings, err := h.client.DoGetFallback(ctx, u, q) if err != nil { return nil, warnings, err } @@ -670,7 +676,7 @@ func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model. return model.Value(qres.v), warnings, json.Unmarshal(body, &qres) } -func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model.Value, api.Warnings, error) { +func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model.Value, Warnings, error) { u := h.client.URL(epQueryRange, nil) q := u.Query() @@ -679,7 +685,7 @@ func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model. q.Set("end", formatTime(r.End)) q.Set("step", strconv.FormatFloat(r.Step.Seconds(), 'f', -1, 64)) - _, body, warnings, err := api.DoGetFallback(h.client, ctx, u, q) + _, body, warnings, err := h.client.DoGetFallback(ctx, u, q) if err != nil { return nil, warnings, err } @@ -689,7 +695,7 @@ func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model. return model.Value(qres.v), warnings, json.Unmarshal(body, &qres) } -func (h *httpAPI) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, api.Warnings, error) { +func (h *httpAPI) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, Warnings, error) { u := h.client.URL(epSeries, nil) q := u.Query() @@ -796,10 +802,19 @@ func (h *httpAPI) TargetsMetadata(ctx context.Context, matchTarget string, metri return res, json.Unmarshal(body, &res) } +// Warnings is an array of non critical errors +type Warnings []string + // apiClient wraps a regular client and processes successful API responses. // Successful also includes responses that errored at the API level. -type apiClient struct { - api.Client +type apiClient interface { + URL(ep string, args map[string]string) *url.URL + Do(context.Context, *http.Request) (*http.Response, []byte, Warnings, error) + DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) +} + +type apiClientImpl struct { + client api.Client } type apiResponse struct { @@ -825,17 +840,21 @@ func errorTypeAndMsgFor(resp *http.Response) (ErrorType, string) { return ErrBadResponse, fmt.Sprintf("bad response code %d", resp.StatusCode) } -func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { - resp, body, warnings, err := c.Client.Do(ctx, req) +func (h *apiClientImpl) URL(ep string, args map[string]string) *url.URL { + return h.client.URL(ep, args) +} + +func (h *apiClientImpl) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, Warnings, error) { + resp, body, err := h.client.Do(ctx, req) if err != nil { - return resp, body, warnings, err + return resp, body, nil, err } code := resp.StatusCode if code/100 != 2 && !apiError(code) { errorType, errorMsg := errorTypeAndMsgFor(resp) - return resp, body, warnings, &Error{ + return resp, body, nil, &Error{ Type: errorType, Msg: errorMsg, Detail: string(body), @@ -846,7 +865,7 @@ func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, [ if http.StatusNoContent != code { if jsonErr := json.Unmarshal(body, &result); jsonErr != nil { - return resp, body, warnings, &Error{ + return resp, body, nil, &Error{ Type: ErrBadResponse, Msg: jsonErr.Error(), } @@ -867,10 +886,35 @@ func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, [ } } - return resp, []byte(result.Data), warnings, err + return resp, []byte(result.Data), result.Warnings, err } +// DoGetFallback will attempt to do the request as-is, and on a 405 it will fallback to a GET request. +func (h *apiClientImpl) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) { + req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode())) + if err != nil { + return nil, nil, nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, body, warnings, err := h.Do(ctx, req) + if resp != nil && resp.StatusCode == http.StatusMethodNotAllowed { + u.RawQuery = args.Encode() + req, err = http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, nil, warnings, err + } + + } else { + if err != nil { + return resp, body, warnings, err + } + return resp, body, warnings, nil + } + return h.Do(ctx, req) +} + func formatTime(t time.Time) string { return strconv.FormatFloat(float64(t.Unix())+float64(t.Nanosecond())/1e9, 'f', -1, 64) } diff --git a/api/prometheus/v1/api_test.go b/api/prometheus/v1/api_test.go index 0572f8650..b43318e65 100644 --- a/api/prometheus/v1/api_test.go +++ b/api/prometheus/v1/api_test.go @@ -17,8 +17,10 @@ import ( "context" "errors" "fmt" + "io/ioutil" "math" "net/http" + "net/http/httptest" "net/url" "reflect" "strings" @@ -28,12 +30,10 @@ import ( json "github.com/json-iterator/go" "github.com/prometheus/common/model" - - "github.com/prometheus/client_golang/api" ) type apiTest struct { - do func() (interface{}, api.Warnings, error) + do func() (interface{}, Warnings, error) inWarnings []string inErr error inStatusCode int @@ -43,7 +43,7 @@ type apiTest struct { reqParam url.Values reqMethod string res interface{} - warnings api.Warnings + warnings Warnings err error } @@ -64,7 +64,7 @@ func (c *apiTestClient) URL(ep string, args map[string]string) *url.URL { return u } -func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { +func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, Warnings, error) { test := c.curTest @@ -92,102 +92,111 @@ func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Respon return resp, b, test.inWarnings, test.inErr } +func (c *apiTestClient) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) { + req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode())) + if err != nil { + return nil, nil, nil, err + } + return c.Do(ctx, req) +} + func TestAPIs(t *testing.T) { testTime := time.Now() - client := &apiTestClient{T: t} - + tc := &apiTestClient{ + T: t, + } promAPI := &httpAPI{ - client: client, + client: tc, } - doAlertManagers := func() func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doAlertManagers := func() func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.AlertManagers(context.Background()) return v, nil, err } } - doCleanTombstones := func() func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doCleanTombstones := func() func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return nil, nil, promAPI.CleanTombstones(context.Background()) } } - doConfig := func() func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doConfig := func() func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.Config(context.Background()) return v, nil, err } } - doDeleteSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doDeleteSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return nil, nil, promAPI.DeleteSeries(context.Background(), []string{matcher}, startTime, endTime) } } - doFlags := func() func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doFlags := func() func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.Flags(context.Background()) return v, nil, err } } - doLabelNames := func(label string) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doLabelNames := func(label string) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return promAPI.LabelNames(context.Background()) } } - doLabelValues := func(label string) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doLabelValues := func(label string) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return promAPI.LabelValues(context.Background(), label) } } - doQuery := func(q string, ts time.Time) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doQuery := func(q string, ts time.Time) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return promAPI.Query(context.Background(), q, ts) } } - doQueryRange := func(q string, rng Range) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doQueryRange := func(q string, rng Range) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return promAPI.QueryRange(context.Background(), q, rng) } } - doSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { return promAPI.Series(context.Background(), []string{matcher}, startTime, endTime) } } - doSnapshot := func(skipHead bool) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doSnapshot := func(skipHead bool) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.Snapshot(context.Background(), skipHead) return v, nil, err } } - doRules := func() func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doRules := func() func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.Rules(context.Background()) return v, nil, err } } - doTargets := func() func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doTargets := func() func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.Targets(context.Background()) return v, nil, err } } - doTargetsMetadata := func(matchTarget string, metric string, limit string) func() (interface{}, api.Warnings, error) { - return func() (interface{}, api.Warnings, error) { + doTargetsMetadata := func(matchTarget string, metric string, limit string) func() (interface{}, Warnings, error) { + return func() (interface{}, Warnings, error) { v, err := promAPI.TargetsMetadata(context.Background(), matchTarget, metric, limit) return v, nil, err } @@ -855,7 +864,7 @@ func TestAPIs(t *testing.T) { for i, test := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - client.curTest = test + tc.curTest = test res, warnings, err := test.do() @@ -900,14 +909,14 @@ type apiClientTest struct { response interface{} expectedBody string expectedErr *Error - expectedWarnings api.Warnings + expectedWarnings Warnings } func (c *testClient) URL(ep string, args map[string]string) *url.URL { return nil } -func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { +func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { if ctx == nil { c.Fatalf("context was not passed down") } @@ -934,7 +943,7 @@ func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, StatusCode: test.code, } - return resp, b, test.expectedWarnings, nil + return resp, b, nil } func TestAPIClientDo(t *testing.T) { @@ -1065,7 +1074,9 @@ func TestAPIClientDo(t *testing.T) { ch: make(chan apiClientTest, 1), req: &http.Request{}, } - client := &apiClient{tc} + client := &apiClientImpl{ + client: tc, + } for i, test := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { @@ -1209,3 +1220,113 @@ func TestSamplesJsonSerialization(t *testing.T) { }) } } + +type httpTestClient struct { + client http.Client +} + +func (c *httpTestClient) URL(ep string, args map[string]string) *url.URL { + return nil +} + +func (c *httpTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { + resp, err := c.client.Do(req) + if err != nil { + return nil, nil, err + } + + var body []byte + done := make(chan struct{}) + go func() { + body, err = ioutil.ReadAll(resp.Body) + close(done) + }() + + select { + case <-ctx.Done(): + <-done + err = resp.Body.Close() + if err == nil { + err = ctx.Err() + } + case <-done: + } + + return resp, body, err +} + +func TestDoGetFallback(t *testing.T) { + v := url.Values{"a": []string{"1", "2"}} + + type testResponse struct { + Values string + Method string + } + + // Start a local HTTP server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req.ParseForm() + testResp, _ := json.Marshal(&testResponse{ + Values: req.Form.Encode(), + Method: req.Method, + }) + + apiResp := &apiResponse{ + Data: testResp, + } + + body, _ := json.Marshal(apiResp) + + if req.Method == http.MethodPost { + if req.URL.Path == "/blockPost" { + http.Error(w, string(body), http.StatusMethodNotAllowed) + return + } + } + + w.Write(body) + })) + // Close the server when test finishes. + defer server.Close() + + u, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + client := &httpTestClient{client: *(server.Client())} + api := &apiClientImpl{ + client: client, + } + + // Do a post, and ensure that the post succeeds. + _, b, _, err := api.DoGetFallback(context.TODO(), u, v) + if err != nil { + t.Fatalf("Error doing local request: %v", err) + } + resp := &testResponse{} + if err := json.Unmarshal(b, resp); err != nil { + t.Fatal(err) + } + if resp.Method != http.MethodPost { + t.Fatalf("Mismatch method") + } + if resp.Values != v.Encode() { + t.Fatalf("Mismatch in values") + } + + // Do a fallbcak to a get. + u.Path = "/blockPost" + _, b, _, err = api.DoGetFallback(context.TODO(), u, v) + if err != nil { + t.Fatalf("Error doing local request: %v", err) + } + if err := json.Unmarshal(b, resp); err != nil { + t.Fatal(err) + } + if resp.Method != http.MethodGet { + t.Fatalf("Mismatch method") + } + if resp.Values != v.Encode() { + t.Fatalf("Mismatch in values") + } +}