Skip to content

Commit

Permalink
Merge pull request #1313 from jacobbednarz/modernise-dns-firewall
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbednarz committed Jun 19, 2023
2 parents 2da520f + 37b9d25 commit 768e236
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 298 deletions.
7 changes: 7 additions & 0 deletions .changelog/1313.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:breaking-change
virtualdns: remove support in favour of newer DNS firewall methods
```

```release-note:breaking-change
dns_firewall: modernise method signatures and conventions to align with the experimental client
```
96 changes: 60 additions & 36 deletions dns_firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package cloudflare
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)

var ErrMissingClusterID = errors.New("missing required cluster ID")

// DNSFirewallCluster represents a DNS Firewall configuration.
type DNSFirewallCluster struct {
ID string `json:"id,omitempty"`
Expand Down Expand Up @@ -43,9 +44,9 @@ type DNSFirewallAnalytics struct {

// DNSFirewallUserAnalyticsOptions represents range and dimension selection on analytics endpoint.
type DNSFirewallUserAnalyticsOptions struct {
Metrics []string
Since *time.Time
Until *time.Time
Metrics []string `url:"metrics,omitempty" del:","`
Since *time.Time `url:"since,omitempty"`
Until *time.Time `url:"until,omitempty"`
}

// dnsFirewallResponse represents a DNS Firewall response.
Expand All @@ -66,12 +67,42 @@ type dnsFirewallAnalyticsResponse struct {
Result DNSFirewallAnalytics `json:"result"`
}

type CreateDNSFirewallClusterParams struct {
Name string `json:"name"`
UpstreamIPs []string `json:"upstream_ips"`
DNSFirewallIPs []string `json:"dns_firewall_ips,omitempty"`
MinimumCacheTTL uint `json:"minimum_cache_ttl,omitempty"`
MaximumCacheTTL uint `json:"maximum_cache_ttl,omitempty"`
DeprecateAnyRequests bool `json:"deprecate_any_requests"`
}

type GetDNSFirewallClusterParams struct {
ClusterID string `json:"-"`
}

type UpdateDNSFirewallClusterParams struct {
ClusterID string `json:"-"`
Name string `json:"name"`
UpstreamIPs []string `json:"upstream_ips"`
DNSFirewallIPs []string `json:"dns_firewall_ips,omitempty"`
MinimumCacheTTL uint `json:"minimum_cache_ttl,omitempty"`
MaximumCacheTTL uint `json:"maximum_cache_ttl,omitempty"`
DeprecateAnyRequests bool `json:"deprecate_any_requests"`
}

type ListDNSFirewallClustersParams struct{}

type GetDNSFirewallUserAnalyticsParams struct {
ClusterID string `json:"-"`
DNSFirewallUserAnalyticsOptions
}

// CreateDNSFirewallCluster creates a new DNS Firewall cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-create-dns-firewall-cluster
func (api *API) CreateDNSFirewallCluster(ctx context.Context, v DNSFirewallCluster) (*DNSFirewallCluster, error) {
uri := fmt.Sprintf("%s/dns_firewall", api.userBaseURL("/user"))
res, err := api.makeRequestContext(ctx, http.MethodPost, uri, v)
func (api *API) CreateDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, params CreateDNSFirewallClusterParams) (*DNSFirewallCluster, error) {
uri := fmt.Sprintf("/%s/dns_firewall", rc.URLFragment())
res, err := api.makeRequestContext(ctx, http.MethodPost, uri, params)
if err != nil {
return nil, err
}
Expand All @@ -85,11 +116,15 @@ func (api *API) CreateDNSFirewallCluster(ctx context.Context, v DNSFirewallClust
return response.Result, nil
}

// DNSFirewallCluster fetches a single DNS Firewall cluster.
// GetDNSFirewallCluster fetches a single DNS Firewall cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-dns-firewall-cluster-details
func (api *API) DNSFirewallCluster(ctx context.Context, clusterID string) (*DNSFirewallCluster, error) {
uri := fmt.Sprintf("%s/dns_firewall/%s", api.userBaseURL("/user"), clusterID)
func (api *API) GetDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, params GetDNSFirewallClusterParams) (*DNSFirewallCluster, error) {
if params.ClusterID == "" {
return &DNSFirewallCluster{}, ErrMissingClusterID
}

uri := fmt.Sprintf("/%s/dns_firewall/%s", rc.URLFragment(), params.ClusterID)
res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, err
Expand All @@ -107,8 +142,8 @@ func (api *API) DNSFirewallCluster(ctx context.Context, clusterID string) (*DNSF
// ListDNSFirewallClusters lists the DNS Firewall clusters associated with an account.
//
// API reference: https://api.cloudflare.com/#dns-firewall-list-dns-firewall-clusters
func (api *API) ListDNSFirewallClusters(ctx context.Context) ([]*DNSFirewallCluster, error) {
uri := fmt.Sprintf("%s/dns_firewall", api.userBaseURL("/user"))
func (api *API) ListDNSFirewallClusters(ctx context.Context, rc *ResourceContainer, params ListDNSFirewallClustersParams) ([]*DNSFirewallCluster, error) {
uri := fmt.Sprintf("/%s/dns_firewall", rc.URLFragment())
res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, err
Expand All @@ -126,9 +161,13 @@ func (api *API) ListDNSFirewallClusters(ctx context.Context) ([]*DNSFirewallClus
// UpdateDNSFirewallCluster updates a DNS Firewall cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-update-dns-firewall-cluster
func (api *API) UpdateDNSFirewallCluster(ctx context.Context, clusterID string, vv DNSFirewallCluster) error {
uri := fmt.Sprintf("%s/dns_firewall/%s", api.userBaseURL("/user"), clusterID)
res, err := api.makeRequestContext(ctx, http.MethodPatch, uri, vv)
func (api *API) UpdateDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, params UpdateDNSFirewallClusterParams) error {
if params.ClusterID == "" {
return ErrMissingClusterID
}

uri := fmt.Sprintf("/%s/dns_firewall/%s", rc.URLFragment(), params.ClusterID)
res, err := api.makeRequestContext(ctx, http.MethodPatch, uri, params)
if err != nil {
return err
}
Expand All @@ -146,8 +185,8 @@ func (api *API) UpdateDNSFirewallCluster(ctx context.Context, clusterID string,
// undone, and will stop all traffic to that cluster.
//
// API reference: https://api.cloudflare.com/#dns-firewall-delete-dns-firewall-cluster
func (api *API) DeleteDNSFirewallCluster(ctx context.Context, clusterID string) error {
uri := fmt.Sprintf("%s/dns_firewall/%s", api.userBaseURL("/user"), clusterID)
func (api *API) DeleteDNSFirewallCluster(ctx context.Context, rc *ResourceContainer, clusterID string) error {
uri := fmt.Sprintf("/%s/dns_firewall/%s", rc.URLFragment(), clusterID)
res, err := api.makeRequestContext(ctx, http.MethodDelete, uri, nil)
if err != nil {
return err
Expand All @@ -162,24 +201,9 @@ func (api *API) DeleteDNSFirewallCluster(ctx context.Context, clusterID string)
return nil
}

// encode encodes non-nil fields into URL encoded form.
func (o DNSFirewallUserAnalyticsOptions) encode() string {
v := url.Values{}
if o.Since != nil {
v.Set("since", o.Since.UTC().Format(time.RFC3339))
}
if o.Until != nil {
v.Set("until", o.Until.UTC().Format(time.RFC3339))
}
if o.Metrics != nil {
v.Set("metrics", strings.Join(o.Metrics, ","))
}
return v.Encode()
}

// DNSFirewallUserAnalytics retrieves analytics report for a specified dimension and time range.
func (api *API) DNSFirewallUserAnalytics(ctx context.Context, clusterID string, o DNSFirewallUserAnalyticsOptions) (DNSFirewallAnalytics, error) {
uri := fmt.Sprintf("%s/dns_firewall/%s/dns_analytics/report?%s", api.userBaseURL("/user"), clusterID, o.encode())
// GetDNSFirewallUserAnalytics retrieves analytics report for a specified dimension and time range.
func (api *API) GetDNSFirewallUserAnalytics(ctx context.Context, rc *ResourceContainer, params GetDNSFirewallUserAnalyticsParams) (DNSFirewallAnalytics, error) {
uri := buildURI(fmt.Sprintf("/%s/dns_firewall/%s/dns_analytics/report", rc.URLFragment(), params.ClusterID), params.DNSFirewallUserAnalyticsOptions)
res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return DNSFirewallAnalytics{}, err
Expand Down
93 changes: 76 additions & 17 deletions dns_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,74 @@ import (
"github.com/stretchr/testify/assert"
)

func float64Ptr(v float64) *float64 {
return &v
}
func TestDNSFirewallUserAnalytics_UserLevel(t *testing.T) {
setup()
defer teardown()

now := time.Now().UTC()
since := now.Add(-1 * time.Hour)
until := now

handler := func(w http.ResponseWriter, r *http.Request) {
expectedMetrics := "queryCount,uncachedCount,staleCount,responseTimeAvg,responseTimeMedia,responseTime90th,responseTime99th"

assert.Equal(t, http.MethodGet, r.Method, "Expected method 'GET'")
assert.Equal(t, expectedMetrics, r.URL.Query().Get("metrics"), "Expected many metrics in URL parameter")
assert.Equal(t, since.Format(time.RFC3339), r.URL.Query().Get("since"), "Expected since parameter in URL")
assert.Equal(t, until.Format(time.RFC3339), r.URL.Query().Get("until"), "Expected until parameter in URL")

w.Header().Set("content-type", "application/json")
fmt.Fprint(w, `{
"result": {
"totals":{
"queryCount": 5,
"uncachedCount":6,
"staleCount":7,
"responseTimeAvg":1.0,
"responseTimeMedian":2.0,
"responseTime90th":3.0,
"responseTime99th":4.0
}
},
"success": true,
"errors": null,
"messages": null
}`)
}

func int64Ptr(v int64) *int64 {
return &v
mux.HandleFunc("/user/dns_firewall/12345/dns_analytics/report", handler)
want := DNSFirewallAnalytics{
Totals: DNSFirewallAnalyticsMetrics{
QueryCount: Int64Ptr(5),
UncachedCount: Int64Ptr(6),
StaleCount: Int64Ptr(7),
ResponseTimeAvg: Float64Ptr(1.0),
ResponseTimeMedian: Float64Ptr(2.0),
ResponseTime90th: Float64Ptr(3.0),
ResponseTime99th: Float64Ptr(4.0),
},
}

actual, err := client.GetDNSFirewallUserAnalytics(context.Background(), UserIdentifier("foo"), GetDNSFirewallUserAnalyticsParams{ClusterID: "12345", DNSFirewallUserAnalyticsOptions: DNSFirewallUserAnalyticsOptions{
Metrics: []string{
"queryCount",
"uncachedCount",
"staleCount",
"responseTimeAvg",
"responseTimeMedia",
"responseTime90th",
"responseTime99th",
},
Since: &since,
Until: &until,
}})

if assert.NoError(t, err) {
assert.Equal(t, want, actual)
}
}

func TestDNSFirewallUserAnalytics(t *testing.T) {
func TestDNSFirewallUserAnalytics_AccountLevel(t *testing.T) {
setup()
defer teardown()

Expand Down Expand Up @@ -53,20 +112,20 @@ func TestDNSFirewallUserAnalytics(t *testing.T) {
}`)
}

mux.HandleFunc("/user/dns_firewall/12345/dns_analytics/report", handler)
mux.HandleFunc("/accounts/"+testAccountID+"/dns_firewall/12345/dns_analytics/report", handler)
want := DNSFirewallAnalytics{
Totals: DNSFirewallAnalyticsMetrics{
QueryCount: int64Ptr(5),
UncachedCount: int64Ptr(6),
StaleCount: int64Ptr(7),
ResponseTimeAvg: float64Ptr(1.0),
ResponseTimeMedian: float64Ptr(2.0),
ResponseTime90th: float64Ptr(3.0),
ResponseTime99th: float64Ptr(4.0),
QueryCount: Int64Ptr(5),
UncachedCount: Int64Ptr(6),
StaleCount: Int64Ptr(7),
ResponseTimeAvg: Float64Ptr(1.0),
ResponseTimeMedian: Float64Ptr(2.0),
ResponseTime90th: Float64Ptr(3.0),
ResponseTime99th: Float64Ptr(4.0),
},
}

params := DNSFirewallUserAnalyticsOptions{
actual, err := client.GetDNSFirewallUserAnalytics(context.Background(), AccountIdentifier(testAccountID), GetDNSFirewallUserAnalyticsParams{ClusterID: "12345", DNSFirewallUserAnalyticsOptions: DNSFirewallUserAnalyticsOptions{
Metrics: []string{
"queryCount",
"uncachedCount",
Expand All @@ -78,8 +137,8 @@ func TestDNSFirewallUserAnalytics(t *testing.T) {
},
Since: &since,
Until: &until,
}
actual, err := client.DNSFirewallUserAnalytics(context.Background(), "12345", params)
}})

if assert.NoError(t, err) {
assert.Equal(t, want, actual)
}
Expand Down
5 changes: 5 additions & 0 deletions resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ func (rc *ResourceContainer) URLFragment() string {
if rc.Level == "" {
return rc.Identifier
}

if rc.Level == UserRouteLevel {
return "user"
}

return fmt.Sprintf("%s/%s", rc.Level, rc.Identifier)
}

Expand Down
7 changes: 5 additions & 2 deletions resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ func TestResourcURLFragment(t *testing.T) {
container *ResourceContainer
want string
}{
"account resource": {container: AccountIdentifier("foo"), want: "accounts/foo"},
"zone resource": {container: ZoneIdentifier("foo"), want: "zones/foo"},
"account resource": {container: AccountIdentifier("foo"), want: "accounts/foo"},
"zone resource": {container: ZoneIdentifier("foo"), want: "zones/foo"},
// this is pretty well deprecated in favour of `AccountIdentifier` but
// here for completeness.
"user level resource": {container: UserIdentifier("foo"), want: "user"},
"missing level resource": {container: &ResourceContainer{Level: "", Identifier: "foo"}, want: "foo"},
}

Expand Down

0 comments on commit 768e236

Please sign in to comment.