Skip to content

Commit

Permalink
Apply rubocop fixes to Workload Identity Federation code
Browse files Browse the repository at this point in the history
  • Loading branch information
rbclark committed Jun 27, 2022
1 parent eb96073 commit a010945
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 143 deletions.
201 changes: 106 additions & 95 deletions lib/googleauth/external_account.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

require 'time'
require "time"
require "googleauth/oauth2/sts_client"

module Google
# Module Auth provides classes that provide Google-specific authorization
# used to access Google APIs.
module Auth

# Authenticates requests using External Account credentials, such
# as those provided by the AWS provider.
class ExternalAccountCredentials
extend CredentialsLoader
attr_reader :project_id
Expand All @@ -46,8 +47,7 @@ def self.make_creds options = {}
)
end

# Reads the fields from the
# JSON key.
# Reads the required fields from the JSON.
def self.read_json_key json_key_io
json_key = MultiJson.load json_key_io.read
wanted = [
Expand All @@ -60,11 +60,14 @@ def self.read_json_key json_key_io
end
end

# This module handles the retrieval of credentials from Google Cloud
# by utilizing the AWS EC2 metadata service and then exchanging the
# credentials for a short-lived Google Cloud access token.
class AwsCredentials
AUTH_METADATA_KEY = :authorization
STS_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
IAM_SCOPE = ["https://www.googleapis.com/auth/iam"]
STS_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange".freeze
STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token".freeze
IAM_SCOPE = ["https://www.googleapis.com/auth/iam"].freeze

def initialize options = {}
@audience = options[:audience]
Expand All @@ -80,46 +83,22 @@ def initialize options = {}
@regional_cred_verification_url = @credential_source["regional_cred_verification_url"]

@region = options[:region] || region(options)
@request_signer = AwsRequestSigner.new(@region)
@request_signer = AwsRequestSigner.new @region

@expiry = nil
@access_token = nil

@sts_client = Google::Auth::OAuth2::STSClient.new(token_exchange_endpoint: @token_url)
@sts_client = Google::Auth::OAuth2::STSClient.new token_exchange_endpoint: @token_url
end

def fetch_access_token! options = {}
credentials = fetch_security_credentials options

request_options = @request_signer.generate_signed_request(
credentials,
@regional_cred_verification_url.sub("{region}", @region),
"POST"
)

request_headers = request_options[:headers]
request_headers["x-goog-cloud-target-resource"] = @audience

aws_signed_request = {
headers: [],
method: request_options[:method],
url: request_options[:url]
}
aws_signed_request[:headers] = request_headers.keys.sort.map do |key|
{key: key, value: request_headers[key]}
end
response = exchange_token credentials

response = @sts_client.exchange_token(
audience: @audience,
grant_type: STS_GRANT_TYPE,
subject_token: uri_escape(aws_signed_request.to_json),
subject_token_type: @subject_token_type,
scopes: @service_account_impersonation_url ? IAM_SCOPE : @scope,
requested_token_type: STS_REQUESTED_TOKEN_TYPE
)
if(@service_account_impersonation_url)
impersonated_response = get_impersonated_access_token(response["access_token"])
@expiry = Time.parse(impersonated_response["expireTime"])
if @service_account_impersonation_url
impersonated_response = get_impersonated_access_token response["access_token"]
@expiry = Time.parse impersonated_response["expireTime"]
@access_token = impersonated_response["accessToken"]
else
# Extract the expiration time in seconds from the response and calculate the actual expiration time
Expand Down Expand Up @@ -155,15 +134,45 @@ def apply a_hash, opts = {}

private

def exchange_token credentials
request_options = @request_signer.generate_signed_request(
credentials,
@regional_cred_verification_url.sub("{region}", @region),
"POST"
)

request_headers = request_options[:headers]
request_headers["x-goog-cloud-target-resource"] = @audience

aws_signed_request = {
headers: [],
method: request_options[:method],
url: request_options[:url]
}

aws_signed_request[:headers] = request_headers.keys.sort.map do |key|
{ key: key, value: request_headers[key] }
end

@sts_client.exchange_token(
audience: @audience,
grant_type: STS_GRANT_TYPE,
subject_token: uri_escape(aws_signed_request.to_json),
subject_token_type: @subject_token_type,
scopes: @service_account_impersonation_url ? IAM_SCOPE : @scope,
requested_token_type: STS_REQUESTED_TOKEN_TYPE
)
end

def get_impersonated_access_token token, options = {}
c = options[:connection] || Faraday.default_connection

response = c.post(@service_account_impersonation_url) do |req|
response = c.post @service_account_impersonation_url do |req|
req.headers["Authorization"] = "Bearer #{token}"
req.headers["Content-Type"] = "application/json"
req.body = MultiJson.dump({
"scope": @scope
})
scope: @scope
})
end

if response.status != 200
Expand All @@ -173,11 +182,11 @@ def get_impersonated_access_token token, options = {}
MultiJson.load response.body
end

def uri_escape(string)
def uri_escape string
if string.nil?
nil
else
CGI.escape(string.encode('UTF-8')).gsub('+', '%20').gsub('%7E', '~')
CGI.escape(string.encode("UTF-8")).gsub("+", "%20").gsub("%7E", "~")
end
end

Expand All @@ -201,7 +210,7 @@ def fetch_security_credentials options = {}
role_name = fetch_metadata_role_name options
credentials = fetch_metadata_security_credentials role_name, options

return {
{
access_key_id: credentials["AccessKeyId"],
secret_access_key: credentials["SecretAccessKey"],
session_token: credentials["Token"]
Expand All @@ -224,21 +233,21 @@ def fetch_metadata_role_name options = {}
raise "Unable to determine the AWS role attached to the current AWS workload"
end

return response.body
response.body
end

# Retrieves the AWS security credentials required for signing AWS
# requests from the AWS metadata server.
def fetch_metadata_security_credentials role_name, options = {}
c = options[:connection] || Faraday.default_connection

response = c.get "#{@credential_source_url}/#{role_name}", {}, {"Content-Type": "application/json"}
response = c.get "#{@credential_source_url}/#{role_name}", {}, { "Content-Type": "application/json" }

unless response.success?
raise "Unable to fetch the AWS security credentials required for signing AWS requests"
end

return MultiJson.load response.body
MultiJson.load response.body
end

# Region may already be set, if it is then it can just be returned
Expand All @@ -247,7 +256,7 @@ def region options = {}
raise "region_url or region must be set for external account credentials" unless @region_url

c = options[:connection] || Faraday.default_connection
@region ||= c.get(@region_url).body[..-2]
@region ||= c.get(@region_url).body[0..-2]
end

@region
Expand Down Expand Up @@ -276,38 +285,32 @@ def initialize region_name
# method (str): The HTTP method used to call this API.
# Returns:
# Hash[str, str]: The AWS signed request dictionary object.
def generate_signed_request(aws_credentials, url, method, request_payload="")
def generate_signed_request aws_credentials, url, method, request_payload = ""
headers = {}

uri = URI.parse url

headers['host'] = uri.host

if !uri.hostname || uri.scheme != "https"
raise "Invalid AWS service URL"
end

service_name = uri.host.split(".").first

datetime = Time.now.utc.strftime("%Y%m%dT%H%M%SZ")
date = datetime[0,8]
datetime = Time.now.utc.strftime "%Y%m%dT%H%M%SZ"
date = datetime[0, 8]

headers['x-amz-date'] = datetime
headers['x-amz-security-token'] = aws_credentials[:session_token] if aws_credentials[:session_token]
headers["host"] = uri.host
headers["x-amz-date"] = datetime
headers["x-amz-security-token"] = aws_credentials[:session_token] if aws_credentials[:session_token]

content_sha256 = sha256_hexdigest(request_payload)
content_sha256 = sha256_hexdigest request_payload

canonical_req = canonical_request(method, uri, headers, content_sha256)
sts = string_to_sign(datetime, canonical_req, service_name)
canonical_req = canonical_request method, uri, headers, content_sha256
sts = string_to_sign datetime, canonical_req, service_name

# Authorization header requires everything else to be properly setup in order to be properly
# calculated.
headers['Authorization'] = [
"AWS4-HMAC-SHA256",
"Credential=#{credential(aws_credentials[:access_key_id], date, service_name)},",
"SignedHeaders=#{headers.keys.sort.join(';')},",
"Signature=#{signature(aws_credentials[:secret_access_key], date, sts, service_name)}"
].join(" ")
headers["Authorization"] = build_authorization_header headers, sts, aws_credentials, service_name, date

{
url: uri.to_s,
Expand All @@ -318,47 +321,55 @@ def generate_signed_request(aws_credentials, url, method, request_payload="")

private

def signature(secret_access_key, date, string_to_sign, service)
k_date = hmac("AWS4" + secret_access_key, date)
k_region = hmac(k_date, @region_name)
k_service = hmac(k_region, service)
k_credentials = hmac(k_service, 'aws4_request')
def build_authorization_header headers, sts, aws_credentials, service_name, date
[
"AWS4-HMAC-SHA256",
"Credential=#{credential aws_credentials[:access_key_id], date, service_name},",
"SignedHeaders=#{headers.keys.sort.join ';'},",
"Signature=#{signature aws_credentials[:secret_access_key], date, sts, service_name}"
].join(" ")
end

hexhmac(k_credentials, string_to_sign)
def signature secret_access_key, date, string_to_sign, service
k_date = hmac "AWS4#{secret_access_key}", date
k_region = hmac k_date, @region_name
k_service = hmac k_region, service
k_credentials = hmac k_service, "aws4_request"

hexhmac k_credentials, string_to_sign
end

def hmac(key, value)
OpenSSL::HMAC.digest(OpenSSL::Digest.new('sha256'), key, value)
def hmac key, value
OpenSSL::HMAC.digest OpenSSL::Digest.new("sha256"), key, value
end

def hexhmac(key, value)
OpenSSL::HMAC.hexdigest(OpenSSL::Digest.new('sha256'), key, value)
def hexhmac key, value
OpenSSL::HMAC.hexdigest OpenSSL::Digest.new("sha256"), key, value
end

def credential(access_key_id, date, service)
"#{access_key_id}/#{credential_scope(date, service)}"
def credential access_key_id, date, service
"#{access_key_id}/#{credential_scope date, service}"
end

def credential_scope(date, service)
def credential_scope date, service
[
date,
@region_name,
service,
'aws4_request',
].join('/')
"aws4_request"
].join("/")
end

def string_to_sign(datetime, canonical_request, service)

def string_to_sign datetime, canonical_request, service
[
'AWS4-HMAC-SHA256',
"AWS4-HMAC-SHA256",
datetime,
credential_scope(datetime[0,8], service),
credential_scope(datetime[0, 8], service),
sha256_hexdigest(canonical_request)
].join("\n")
end

def host(uri)
def host uri
# Handles known and unknown URI schemes; default_port nil when unknown.
if uri.default_port == uri.port
uri.host
Expand All @@ -367,37 +378,37 @@ def host(uri)
end
end

def canonical_request(http_method, uri, headers, content_sha256)
def canonical_request http_method, uri, headers, content_sha256
headers = headers.sort_by(&:first) # transforms to a sorted array of [key, value]

[
http_method,
uri.path.empty? ? '/' : uri.path,
build_canonical_querystring(uri.query || ''),
headers.map { |k,v| "#{k}:#{v}\n" }.join, # Canonical headers
uri.path.empty? ? "/" : uri.path,
build_canonical_querystring(uri.query || ""),
headers.map { |k, v| "#{k}:#{v}\n" }.join, # Canonical headers
headers.map(&:first).join(";"), # Signed headers
content_sha256
].join("\n")
end

def sha256_hexdigest(string)
OpenSSL::Digest::SHA256.hexdigest(string)
def sha256_hexdigest string
OpenSSL::Digest::SHA256.hexdigest string
end

# Generates the canonical query string given a raw query string.
# Logic is based on
# https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
# Code is from the AWS SDK for Ruby
# https://github.com/aws/aws-sdk-ruby/blob/0ac3d0a393ed216290bfb5f0383380376f6fb1f1/gems/aws-sigv4/lib/aws-sigv4/signer.rb#L532
def build_canonical_querystring(query)
params = query.split('&')
params = params.map { |p| p.match(/=/) ? p : p + '=' }
def build_canonical_querystring query
params = query.split "&"
params = params.map { |p| p.match(/=/) ? p : "#{p}=" }

params.each.with_index.sort do |a, b|
a, a_offset = a
b, b_offset = b
a_name, a_value = a.split('=')
b_name, b_value = b.split('=')
a_name, a_value = a.split "="
b_name, b_value = b.split "="
if a_name == b_name
if a_value == b_value
a_offset <=> b_offset
Expand All @@ -407,7 +418,7 @@ def build_canonical_querystring(query)
else
a_name <=> b_name
end
end.map(&:first).join('&')
end.map(&:first).join("&")
end
end
end
Expand Down

0 comments on commit a010945

Please sign in to comment.