New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add type annotations and refactor to use TokenDict type #6497
Changes from all commits
c489bd5
950cf95
600852c
f322981
9a56b4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
from enum import Enum | ||
from re import match | ||
from traceback import format_exc | ||
from typing import TYPE_CHECKING | ||
from typing import TYPE_CHECKING, Any, cast, Iterator, Optional | ||
|
||
from sqlalchemy import select, and_ | ||
from sqlalchemy.exc import IntegrityError | ||
|
@@ -27,6 +27,7 @@ | |
import rucio.core.rse | ||
from rucio.common import exception | ||
from rucio.common.config import config_get_bool | ||
from rucio.common.types import InternalAccount, AccountAttributesDict, AccountDict, AccountUsageModelDict, IdentityDict, UsageDict | ||
from rucio.core.vo import vo_exists | ||
from rucio.db.sqla import models | ||
from rucio.db.sqla.constants import AccountStatus, AccountType | ||
|
@@ -37,7 +38,7 @@ | |
|
||
|
||
@transactional_session | ||
def add_account(account, type_, email, *, session: "Session"): | ||
def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session") -> None: | ||
""" Add an account with the given account name and type. | ||
|
||
:param account: the name of the new account. | ||
|
@@ -63,7 +64,7 @@ def add_account(account, type_, email, *, session: "Session"): | |
|
||
|
||
@read_session | ||
def account_exists(account, *, session: "Session"): | ||
def account_exists(account: InternalAccount, *, session: "Session") -> bool: | ||
""" Checks to see if account exists and is active. | ||
|
||
:param account: Name of the account. | ||
|
@@ -81,7 +82,7 @@ def account_exists(account, *, session: "Session"): | |
|
||
|
||
@read_session | ||
def get_account(account, *, session: "Session"): | ||
def get_account(account: InternalAccount, *, session: "Session") -> models.Account: | ||
""" Returns an account for the given account name. | ||
|
||
:param account: the name of the account. | ||
|
@@ -102,7 +103,7 @@ def get_account(account, *, session: "Session"): | |
|
||
|
||
@transactional_session | ||
def del_account(account, *, session: "Session"): | ||
def del_account(account: InternalAccount, *, session: "Session") -> None: | ||
""" Disable an account with the given account name. | ||
|
||
:param account: the account name. | ||
|
@@ -115,15 +116,15 @@ def del_account(account, *, session: "Session"): | |
models.Account.status == AccountStatus.ACTIVE | ||
) | ||
try: | ||
account = session.execute(query).scalar_one() | ||
account_result = session.execute(query).scalar_one() | ||
except exc.NoResultFound: | ||
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) | ||
|
||
account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) | ||
account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) | ||
|
||
|
||
@transactional_session | ||
def update_account(account, key, value, *, session: "Session"): | ||
def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None: | ||
""" Update a property of an account. | ||
|
||
:param account: Name of the account. | ||
|
@@ -137,22 +138,22 @@ def update_account(account, key, value, *, session: "Session"): | |
models.Account.account == account | ||
) | ||
try: | ||
account = session.execute(query).scalar_one() | ||
query_result = session.execute(query).scalar_one() | ||
except exc.NoResultFound: | ||
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) | ||
if key == 'status': | ||
if isinstance(value, str): | ||
value = AccountStatus[value] | ||
if value == AccountStatus.SUSPENDED: | ||
account.update({'status': value, 'suspended_at': datetime.utcnow()}) | ||
query_result.update({'status': value, 'suspended_at': datetime.utcnow()}) | ||
elif value == AccountStatus.ACTIVE: | ||
account.update({'status': value, 'suspended_at': None}) | ||
query_result.update({'status': value, 'suspended_at': None}) | ||
else: | ||
account.update({key: value}) | ||
query_result.update({key: value}) | ||
|
||
|
||
@stream_session | ||
def list_accounts(filter_=None, *, session: "Session"): | ||
def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[AccountDict]: | ||
""" Returns a list of all account names. | ||
|
||
:param filter_: Dictionary of attributes by which the input data should be filtered | ||
|
@@ -213,7 +214,7 @@ def list_accounts(filter_=None, *, session: "Session"): | |
|
||
|
||
@read_session | ||
def list_identities(account, *, session: "Session"): | ||
def list_identities(account: InternalAccount, *, session: "Session") -> list[IdentityDict]: | ||
""" | ||
List all identities on an account. | ||
|
||
|
@@ -244,11 +245,11 @@ def list_identities(account, *, session: "Session"): | |
).where( | ||
models.IdentityAccountAssociation.account == account | ||
) | ||
return [row._asdict() for row in session.execute(query)] | ||
return [cast(IdentityDict, row._asdict()) for row in session.execute(query)] | ||
Comment on lines
-247
to
+248
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not so much a comment for you, but more so for me: understand the necessity to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is what occurs when I don't use
See here for reference: https://github.com/rdimaio/rucio/actions/runs/8206991793/job/22447307300#step:9:10 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the $ cat --number test.py
1 #!/usr/bin/env python3
2
3 from typing import TypedDict, cast
4
5 class Struct(TypedDict):
6 foo: str
7 bar: int
8
9 def test1() -> Struct:
10 return {'foo': 'foo', 'bar': 0}
11
12 def test2() -> Struct:
13 return {'baz': None}
14
15 def test3() -> Struct:
16 return cast(Struct, {'baz': None})
$ mypy test.py
test.py:13: error: Missing keys ("foo", "bar") for TypedDict "Struct" [typeddict-item]
test.py:13: error: Extra key "baz" for TypedDict "Struct" [typeddict-unknown-key]
Found 2 errors in 1 file (checked 1 source file) Does that mean that we need to avoid the use of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of using
see test commit for reference: rdimaio@cff7fa2 And pyright report: https://github.com/rdimaio/rucio/actions/runs/8233368875/job/22512678528 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’m afraid that I failed to verify that it will raise an error in case of a mismatch: $ cat --number test.py
1 #!/usr/bin/env python3
2
3 from typing import TypedDict
4
5 from sqlalchemy import select
6 from sqlalchemy.orm import Session
7
8 from rucio.db.sqla.models import Identity
9 from rucio.db.sqla.session import read_session
10
11
12 # This is wrong.
13 class IdentityDict(TypedDict):
14 foo: str
15 bar: int
16
17
18 @read_session
19 def list_identities(*, session: Session) -> list[IdentityDict]:
20 query = select(
21 Identity.identity
22 )
23 return [IdentityDict(**result._asdict()) for result in session.execute(query)]
24
25 l = list_identities()
$ pyright test.py
0 errors, 0 warnings, 0 informations There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 6. Incompatible TypedDict, passing parameters directly, no extra parameters
Pyright doesn't complain
Mypy complains
7. Incompatible TypedDict, passing parameters directly, extra parameters
Pyright complains
Mypy complains
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 8. Incompatible Pydantic model, passing parameters directly, static type checking
Pydantic complains here, as we saw earlier. From the point of view of the static type checkers: Pyright doesn't complain
Mypy complains
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 9. Incompatible dataclass, passing parameters directly
Pyright doesn't complain
Mypy complains (ambiguously)
Mypy complains even if I change 10. Incompatible dataclass, passing parameters directly
Pyright complains
Mypy complains
11. Incompatible TypedDict, passing parameters directly, adding explicitly types for the input arguments
Pyright complains
Mypy complains
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Conclusion
In the long term, I think the best approach would be to use mypy+pydantic and use pydantic models for struct typing. For now, it seems to be fine to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Relevant thread #6615 |
||
|
||
|
||
@read_session | ||
def list_account_attributes(account, *, session: "Session"): | ||
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[AccountAttributesDict]: | ||
""" | ||
Get all attributes defined for an account. | ||
|
||
|
@@ -274,11 +275,11 @@ def list_account_attributes(account, *, session: "Session"): | |
).where( | ||
models.AccountAttrAssociation.account == account | ||
) | ||
return [row._asdict() for row in session.execute(query)] | ||
return [cast(AccountAttributesDict, row._asdict()) for row in session.execute(query)] | ||
|
||
|
||
@read_session | ||
def has_account_attribute(account, key, *, session: "Session"): | ||
def has_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> bool: | ||
""" | ||
Indicates whether the named key is present for the account. | ||
|
||
|
@@ -298,7 +299,7 @@ def has_account_attribute(account, key, *, session: "Session"): | |
|
||
|
||
@transactional_session | ||
def add_account_attribute(account, key, value, *, session: "Session"): | ||
def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None: | ||
""" | ||
Add an attribute for the given account name. | ||
|
||
|
@@ -334,7 +335,7 @@ def add_account_attribute(account, key, value, *, session: "Session"): | |
|
||
|
||
@transactional_session | ||
def del_account_attribute(account, key, *, session: "Session"): | ||
def del_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> None: | ||
""" | ||
Add an attribute for the given account name. | ||
|
||
|
@@ -355,7 +356,7 @@ def del_account_attribute(account, key, *, session: "Session"): | |
|
||
|
||
@read_session | ||
def get_usage(rse_id, account, *, session: "Session"): | ||
def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> UsageDict: | ||
""" | ||
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. | ||
|
||
|
@@ -373,13 +374,13 @@ def get_usage(rse_id, account, *, session: "Session"): | |
models.AccountUsage.account == account | ||
) | ||
try: | ||
return session.execute(query).one()._asdict() | ||
return cast(UsageDict, session.execute(query).one()._asdict()) | ||
except exc.NoResultFound: | ||
return {'bytes': 0, 'files': 0, 'updated_at': None} | ||
return UsageDict({'bytes': 0, 'files': 0, 'updated_at': None}) | ||
Comment on lines
-378
to
+379
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replied here: #6497 (comment) |
||
|
||
|
||
@read_session | ||
def get_all_rse_usages_per_account(account, *, session: "Session"): | ||
def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[AccountUsageModelDict]: | ||
""" | ||
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. | ||
|
||
|
@@ -394,20 +395,20 @@ def get_all_rse_usages_per_account(account, *, session: "Session"): | |
models.AccountUsage.account == account | ||
) | ||
try: | ||
return [result.to_dict() for result in session.execute(query).scalars()] | ||
return [cast(AccountUsageModelDict, result.to_dict()) for result in session.execute(query).scalars()] | ||
except exc.NoResultFound: | ||
return [] | ||
|
||
|
||
@read_session | ||
def get_usage_history(rse_id, account, *, session: "Session"): | ||
def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[UsageDict]: | ||
""" | ||
Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist. | ||
|
||
:param rse_id: The id of the RSE. | ||
:param account: The account name. | ||
:param session: The database session in use. | ||
:returns: A dictionary {'bytes', 'files', 'updated_at'} | ||
:returns: A list of dictionaries {'bytes', 'files', 'updated_at'} | ||
""" | ||
query = select( | ||
models.AccountUsageHistory.bytes, | ||
|
@@ -420,7 +421,6 @@ def get_usage_history(rse_id, account, *, session: "Session"): | |
models.AccountUsageHistory.updated_at | ||
) | ||
try: | ||
return [row._asdict() for row in session.execute(query)] | ||
return [cast(UsageDict, row._asdict()) for row in session.execute(query)] | ||
except exc.NoResultFound: | ||
raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session))) | ||
return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use
isort
if you don’t want to do the work manually. 😄