diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index e8a0d6c7e5..0ddd81ac88 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -220,15 +220,20 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session" The token lifetime is 1 hour. :param account: Account identifier as a string. - :param signature: Response to server challenge signed with SSH private key as string. + :param signature: Response to server challenge signed with SSH private key as a base64 encoded string. :param appid: The application identifier as a string. :param ip: IP address of the client as a string. :param session: The database session in use. :returns: A dict with token and expires_at entries. """ - if not isinstance(signature, bytes): - signature = signature.encode() + + # decode the signature which must come in base64 encoded + try: + signature += '=' * ((4 - len(signature) % 4) % 4) # adding required padding + signature = b64decode(signature) + except TypeError: + raise CannotAuthenticate(f'Cannot authenticate to account {account} with malformed signature') # Make sure the account exists if not account_exists(account, session=session): diff --git a/lib/rucio/web/rest/flaskapi/v1/auth.py b/lib/rucio/web/rest/flaskapi/v1/auth.py index 70677b07a1..39915aa9c8 100644 --- a/lib/rucio/web/rest/flaskapi/v1/auth.py +++ b/lib/rucio/web/rest/flaskapi/v1/auth.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import base64 import json import logging import time @@ -1215,18 +1214,6 @@ def get(self): appid = request.headers.get('X-Rucio-AppID', default='unknown') ip = request.headers.get('X-Forwarded-For', default=request.remote_addr) - # decode the signature which must come in base64 encoded - try: - signature += '=' * ((4 - len(signature) % 4) % 4) # adding required padding - signature = base64.b64decode(signature) - except TypeError: - return generate_http_error_flask( - status_code=401, - exc=CannotAuthenticate.__name__, - exc_msg=f'Cannot authenticate to account {account} with malformed signature', - headers=headers - ) - try: result = get_auth_token_ssh(account, signature, appid, ip, vo=vo) except AccessDenied: diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9491c4632c..fbb98fb644 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import base64 import datetime import time @@ -125,7 +124,7 @@ def test_get_auth_token_ssh_success(self, vo, root_account): challenge_token = get_ssh_challenge_token(account='root', appid='test', ip='127.0.0.1', vo=vo).get('token') - signature = base64.b64decode(ssh_sign(PRIVATE_KEY, challenge_token)) + signature = ssh_sign(PRIVATE_KEY, challenge_token) result = get_auth_token_ssh(account='root', signature=signature, appid='test', ip='127.0.0.1', vo=vo) @@ -159,8 +158,7 @@ def test_invalid_padding(self, vo, root_account): challenge_token = get_ssh_challenge_token(account='root', appid='test', ip='127.0.0.1', vo=vo).get('token') - ssh_sign_string = ssh_sign(PRIVATE_KEY, challenge_token) - signature = base64.b64decode(ssh_sign_string) + signature = ssh_sign(PRIVATE_KEY, challenge_token) result = get_auth_token_ssh(account='root', signature=signature, appid='test', ip='127.0.0.1', vo=vo) assert result is not None