Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

revamp DNS firewall methods, remove virtual DNS #1313

Merged
merged 4 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changelog/1313.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:breaking-change
virtualdns: remove support in favour of newer DNS firewall methods
```

```release-note:breaking-change
dns_firewall: modernise method signatures and conventions to align with the experimental client
```
96 changes: 60 additions & 36 deletions dns_firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package cloudflare
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)

var ErrMissingClusterID = errors.New("missing required cluster ID")

// DNSFirewallCluster represents a DNS Firewall configuration.
type DNSFirewallCluster struct {
ID string `json:"id,omitempty"`
Expand Down Expand Up @@ -43,9 +44,9 @@ type DNSFirewallAnalytics struct {

// DNSFirewallUserAnalyticsOptions represents range and dimension selection on analytics endpoint.
type DNSFirewallUserAnalyticsOptions struct {
Metrics []string
Since *time.Time
Until *time.Time
Metrics []string `url:"metrics,omitempty" del:","`
Since *time.Time `url:"since,omitempty"`
Until *time.Time `url:"until,omitempty"`
}

// dnsFirewallResponse represents a DNS Firewall response.
Expand All @@ -66,12 +67,42 @@ type dnsFirewallAnalyticsResponse struct {
Result DNSFirewallAnalytics `json:"result"`
}

type CreateDNSFirewallClusterParams struct {
Name string `json:"name"`
UpstreamIPs []string `json:"upstream_ips"`
DNSFirewallIPs []string `json:"dns_firewall_ips,omitempty"`
MinimumCacheTTL uint `json:"minimum_cache_ttl,omitempty"`
MaximumCacheTTL uint `json:"maximum_cache_ttl,omitempty"`
DeprecateAnyRequests bool `json:"deprecate_any_requests"`
}

type GetDNSFirewallClusterParams struct {
ClusterID string `json:"-"`
}

type UpdateDNSFirewallClusterParams struct {
ClusterID string `json:"-"`
Name string `json:"name"`
UpstreamIPs []string `json:"upstream_ips"`
DNSFirewallIPs []string `json:"dns_firewall_ips,omitempty"`
MinimumCacheTTL uint `json:"minimum_cache_ttl,omitempty"`
MaximumCacheTTL uint `json:"maximum_cache_ttl,omitempty"`
DeprecateAnyRequests bool `json:"deprecate_any_requests"`
}

type ListDNSFirewallClustersParams struct{}

type GetDNSFirewallUserAnalyticsParams struct {
ClusterID string `json:"-"`
DNSFirewallUserAnalyticsOptions
}

// CreateDNSFirewallCluster creates a new DNS Firewall cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-create-dns-firewall-cluster
func (api *API) CreateDNSFirewallCluster(ctx context.Context, v DNSFirewallCluster) (*DNSFirewallCluster, error) {
uri := fmt.Sprintf("%s/dns_firewall", api.userBaseURL("/user"))
res, err := api.makeRequestContext(ctx, http.MethodPost, uri, v)
func (api *API) CreateDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, params CreateDNSFirewallClusterParams) (*DNSFirewallCluster, error) {
uri := fmt.Sprintf("/%s/dns_firewall", rc.URLFragment())
res, err := api.makeRequestContext(ctx, http.MethodPost, uri, params)
if err != nil {
return nil, err
}
Expand All @@ -85,11 +116,15 @@ func (api *API) CreateDNSFirewallCluster(ctx context.Context, v DNSFirewallClust
return response.Result, nil
}

// DNSFirewallCluster fetches a single DNS Firewall cluster.
// GetDNSFirewallCluster fetches a single DNS Firewall cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-dns-firewall-cluster-details
func (api *API) DNSFirewallCluster(ctx context.Context, clusterID string) (*DNSFirewallCluster, error) {
uri := fmt.Sprintf("%s/dns_firewall/%s", api.userBaseURL("/user"), clusterID)
func (api *API) GetDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, params GetDNSFirewallClusterParams) (*DNSFirewallCluster, error) {
if params.ClusterID == "" {
return &DNSFirewallCluster{}, ErrMissingClusterID
}

uri := fmt.Sprintf("/%s/dns_firewall/%s", rc.URLFragment(), params.ClusterID)
res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, err
Expand All @@ -107,8 +142,8 @@ func (api *API) DNSFirewallCluster(ctx context.Context, clusterID string) (*DNSF
// ListDNSFirewallClusters lists the DNS Firewall clusters associated with an account.
//
// API reference: https://api.cloudflare.com/#dns-firewall-list-dns-firewall-clusters
func (api *API) ListDNSFirewallClusters(ctx context.Context) ([]*DNSFirewallCluster, error) {
uri := fmt.Sprintf("%s/dns_firewall", api.userBaseURL("/user"))
func (api *API) ListDNSFirewallClusters(ctx context.Context, rc *ResourceContainer, params ListDNSFirewallClustersParams) ([]*DNSFirewallCluster, error) {
uri := fmt.Sprintf("/%s/dns_firewall", rc.URLFragment())
res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, err
Expand All @@ -126,9 +161,13 @@ func (api *API) ListDNSFirewallClusters(ctx context.Context) ([]*DNSFirewallClus
// UpdateDNSFirewallCluster updates a DNS Firewall cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-update-dns-firewall-cluster
func (api *API) UpdateDNSFirewallCluster(ctx context.Context, clusterID string, vv DNSFirewallCluster) error {
uri := fmt.Sprintf("%s/dns_firewall/%s", api.userBaseURL("/user"), clusterID)
res, err := api.makeRequestContext(ctx, http.MethodPatch, uri, vv)
func (api *API) UpdateDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, params UpdateDNSFirewallClusterParams) error {
if params.ClusterID == "" {
return ErrMissingClusterID
}

uri := fmt.Sprintf("/%s/dns_firewall/%s", rc.URLFragment(), params.ClusterID)
res, err := api.makeRequestContext(ctx, http.MethodPatch, uri, params)
if err != nil {
return err
}
Expand All @@ -146,8 +185,8 @@ func (api *API) UpdateDNSFirewallCluster(ctx context.Context, clusterID string,
// undone, and will stop all traffic to that cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-delete-dns-firewall-cluster
func (api *API) DeleteDNSFirewallCluster(ctx context.Context, clusterID string) error {
uri := fmt.Sprintf("%s/dns_firewall/%s", api.userBaseURL("/user"), clusterID)
func (api *API) DeleteDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, clusterID string) error {
uri := fmt.Sprintf("/%s/dns_firewall/%s", rc.URLFragment(), clusterID)
res, err := api.makeRequestContext(ctx, http.MethodDelete, uri, nil)
if err != nil {
return err
Expand All @@ -162,24 +201,9 @@ func (api *API) DeleteDNSFirewallCluster(ctx context.Context, clusterID string)
return nil
}

// encode encodes non-nil fields into URL encoded form.
func (o DNSFirewallUserAnalyticsOptions) encode() string {
v := url.Values{}
if o.Since != nil {
v.Set("since", o.Since.UTC().Format(time.RFC3339))
}
if o.Until != nil {
v.Set("until", o.Until.UTC().Format(time.RFC3339))
}
if o.Metrics != nil {
v.Set("metrics", strings.Join(o.Metrics, ","))
}
return v.Encode()
}

// DNSFirewallUserAnalytics retrieves analytics report for a specified dimension and time range.
func (api *API) DNSFirewallUserAnalytics(ctx context.Context, clusterID string, o DNSFirewallUserAnalyticsOptions) (DNSFirewallAnalytics, error) {
uri := fmt.Sprintf("%s/dns_firewall/%s/dns_analytics/report?%s", api.userBaseURL("/user"), clusterID, o.encode())
// GetDNSFirewallUserAnalytics retrieves analytics report for a specified dimension and time range.
func (api *API) GetDNSFirewallUserAnalytics(ctx context.Context, rc *ResourceContainer, params GetDNSFirewallUserAnalyticsParams) (DNSFirewallAnalytics, error) {
uri := buildURI(fmt.Sprintf("/%s/dns_firewall/%s/dns_analytics/report", rc.URLFragment(), params.ClusterID), params.DNSFirewallUserAnalyticsOptions)
res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return DNSFirewallAnalytics{}, err
Expand Down
93 changes: 76 additions & 17 deletions dns_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,74 @@ import (
"github.com/stretchr/testify/assert"
)

func float64Ptr(v float64) *float64 {
return &v
}
func TestDNSFirewallUserAnalytics_UserLevel(t *testing.T) {
setup()
defer teardown()

now := time.Now().UTC()
since := now.Add(-1 * time.Hour)
until := now

handler := func(w http.ResponseWriter, r *http.Request) {
expectedMetrics := "queryCount,uncachedCount,staleCount,responseTimeAvg,responseTimeMedia,responseTime90th,responseTime99th"

assert.Equal(t, http.MethodGet, r.Method, "Expected method 'GET'")
assert.Equal(t, expectedMetrics, r.URL.Query().Get("metrics"), "Expected many metrics in URL parameter")
assert.Equal(t, since.Format(time.RFC3339), r.URL.Query().Get("since"), "Expected since parameter in URL")
assert.Equal(t, until.Format(time.RFC3339), r.URL.Query().Get("until"), "Expected until parameter in URL")

w.Header().Set("content-type", "application/json")
fmt.Fprint(w, `{
"result": {
"totals":{
"queryCount": 5,
"uncachedCount":6,
"staleCount":7,
"responseTimeAvg":1.0,
"responseTimeMedian":2.0,
"responseTime90th":3.0,
"responseTime99th":4.0
}
},
"success": true,
"errors": null,
"messages": null
}`)
}

func int64Ptr(v int64) *int64 {
return &v
mux.HandleFunc("/user/dns_firewall/12345/dns_analytics/report", handler)
want := DNSFirewallAnalytics{
Totals: DNSFirewallAnalyticsMetrics{
QueryCount: Int64Ptr(5),
UncachedCount: Int64Ptr(6),
StaleCount: Int64Ptr(7),
ResponseTimeAvg: Float64Ptr(1.0),
ResponseTimeMedian: Float64Ptr(2.0),
ResponseTime90th: Float64Ptr(3.0),
ResponseTime99th: Float64Ptr(4.0),
},
}

actual, err := client.GetDNSFirewallUserAnalytics(context.Background(), UserIdentifier("foo"), GetDNSFirewallUserAnalyticsParams{ClusterID: "12345", DNSFirewallUserAnalyticsOptions: DNSFirewallUserAnalyticsOptions{
Metrics: []string{
"queryCount",
"uncachedCount",
"staleCount",
"responseTimeAvg",
"responseTimeMedia",
"responseTime90th",
"responseTime99th",
},
Since: &since,
Until: &until,
}})

if assert.NoError(t, err) {
assert.Equal(t, want, actual)
}
}

func TestDNSFirewallUserAnalytics(t *testing.T) {
func TestDNSFirewallUserAnalytics_AccountLevel(t *testing.T) {
setup()
defer teardown()

Expand Down Expand Up @@ -53,20 +112,20 @@ func TestDNSFirewallUserAnalytics(t *testing.T) {
}`)
}

mux.HandleFunc("/user/dns_firewall/12345/dns_analytics/report", handler)
mux.HandleFunc("/accounts/"+testAccountID+"/dns_firewall/12345/dns_analytics/report", handler)
want := DNSFirewallAnalytics{
Totals: DNSFirewallAnalyticsMetrics{
QueryCount: int64Ptr(5),
UncachedCount: int64Ptr(6),
StaleCount: int64Ptr(7),
ResponseTimeAvg: float64Ptr(1.0),
ResponseTimeMedian: float64Ptr(2.0),
ResponseTime90th: float64Ptr(3.0),
ResponseTime99th: float64Ptr(4.0),
QueryCount: Int64Ptr(5),
UncachedCount: Int64Ptr(6),
StaleCount: Int64Ptr(7),
ResponseTimeAvg: Float64Ptr(1.0),
ResponseTimeMedian: Float64Ptr(2.0),
ResponseTime90th: Float64Ptr(3.0),
ResponseTime99th: Float64Ptr(4.0),
},
}

params := DNSFirewallUserAnalyticsOptions{
actual, err := client.GetDNSFirewallUserAnalytics(context.Background(), AccountIdentifier(testAccountID), GetDNSFirewallUserAnalyticsParams{ClusterID: "12345", DNSFirewallUserAnalyticsOptions: DNSFirewallUserAnalyticsOptions{
Metrics: []string{
"queryCount",
"uncachedCount",
Expand All @@ -78,8 +137,8 @@ func TestDNSFirewallUserAnalytics(t *testing.T) {
},
Since: &since,
Until: &until,
}
actual, err := client.DNSFirewallUserAnalytics(context.Background(), "12345", params)
}})

if assert.NoError(t, err) {
assert.Equal(t, want, actual)
}
Expand Down
5 changes: 5 additions & 0 deletions resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ func (rc *ResourceContainer) URLFragment() string {
if rc.Level == "" {
return rc.Identifier
}

if rc.Level == UserRouteLevel {
return "user"
}

return fmt.Sprintf("%s/%s", rc.Level, rc.Identifier)
}

Expand Down
7 changes: 5 additions & 2 deletions resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ func TestResourcURLFragment(t *testing.T) {
container *ResourceContainer
want string
}{
"account resource": {container: AccountIdentifier("foo"), want: "accounts/foo"},
"zone resource": {container: ZoneIdentifier("foo"), want: "zones/foo"},
"account resource": {container: AccountIdentifier("foo"), want: "accounts/foo"},
"zone resource": {container: ZoneIdentifier("foo"), want: "zones/foo"},
// this is pretty well deprecated in favour of `AccountIdentifier` but
// here for completeness.
"user level resource": {container: UserIdentifier("foo"), want: "user"},
"missing level resource": {container: &ResourceContainer{Level: "", Identifier: "foo"}, want: "foo"},
}

Expand Down