diff --git a/resources/test/fixtures/flake8_pytest_style/PT006.py b/resources/test/fixtures/flake8_pytest_style/PT006.py index 0f9d6c734d895..9e5b229b80a58 100644 --- a/resources/test/fixtures/flake8_pytest_style/PT006.py +++ b/resources/test/fixtures/flake8_pytest_style/PT006.py @@ -39,3 +39,13 @@ def test_list(param1, param2): @pytest.mark.parametrize(["param1"], [1, 2, 3]) def test_list_one_elem(param1, param2): ... + + +@pytest.mark.parametrize([some_expr, another_expr], [1, 2, 3]) +def test_list_expressions(param1, param2): + ... + + +@pytest.mark.parametrize([some_expr, "param2"], [1, 2, 3]) +def test_list_mixed_expr_literal(param1, param2): + ... diff --git a/src/flake8_pytest_style/plugins/parametrize.rs b/src/flake8_pytest_style/plugins/parametrize.rs index 3bb2f9e63fede..02e94e59cebf7 100644 --- a/src/flake8_pytest_style/plugins/parametrize.rs +++ b/src/flake8_pytest_style/plugins/parametrize.rs @@ -16,6 +16,59 @@ fn get_parametrize_decorator<'a>(checker: &Checker, decorators: &'a [Expr]) -> O .find(|decorator| is_pytest_parametrize(decorator, checker)) } +fn elts_to_csv(elts: &[Expr], checker: &Checker) -> Option { + let all_literals = elts.iter().all(|e| { + matches!( + e.node, + ExprKind::Constant { + value: Constant::Str(_), + .. + } + ) + }); + + if !all_literals { + return None; + } + + let mut generator = SourceCodeGenerator::new( + checker.style.indentation(), + checker.style.quote(), + checker.style.line_ending(), + ); + + generator.unparse_expr( + &create_expr(ExprKind::Constant { + value: Constant::Str(elts.iter().fold(String::new(), |mut acc, elt| { + if let ExprKind::Constant { + value: Constant::Str(ref s), + .. + } = elt.node + { + if !acc.is_empty() { + acc.push(','); + } + acc.push_str(s); + } + acc + })), + kind: None, + }), + 0, + ); + + match generator.generate() { + Ok(s) => Some(s), + Err(e) => { + error!( + "Failed to generate CSV string from sequence of names: {}", + e + ); + None + } + } +} + /// PT006 fn check_names(checker: &mut Checker, expr: &Expr) { let names_type = checker.settings.flake8_pytest_style.parametrize_names_type; @@ -134,11 +187,60 @@ fn check_names(checker: &mut Checker, expr: &Expr) { if let Some(first) = elts.first() { handle_single_name(checker, expr, first); } - } else if names_type != types::ParametrizeNameType::Tuple { - checker.add_check(Check::new( - CheckKind::ParametrizeNamesWrongType(names_type), - Range::from_located(expr), - )); + } else { + match names_type { + types::ParametrizeNameType::Tuple => {} + types::ParametrizeNameType::List => { + let mut check = Check::new( + CheckKind::ParametrizeNamesWrongType(names_type), + Range::from_located(expr), + ); + if checker.patch(check.kind.code()) { + let mut generator = SourceCodeGenerator::new( + checker.style.indentation(), + checker.style.quote(), + checker.style.line_ending(), + ); + generator.unparse_expr( + &create_expr(ExprKind::List { + elts: elts.clone(), + ctx: ExprContext::Load, + }), + 0, + ); + match generator.generate() { + Ok(content) => { + check.amend(Fix::replacement( + content, + expr.location, + expr.end_location.unwrap(), + )); + } + Err(e) => error!( + "Failed to fix wrong name(s) type in \ + `@pytest.mark.parametrize`: {e}" + ), + }; + } + checker.add_check(check); + } + types::ParametrizeNameType::CSV => { + let mut check = Check::new( + CheckKind::ParametrizeNamesWrongType(names_type), + Range::from_located(expr), + ); + if checker.patch(check.kind.code()) { + if let Some(content) = elts_to_csv(elts, checker) { + check.amend(Fix::replacement( + content, + expr.location, + expr.end_location.unwrap(), + )); + } + } + checker.add_check(check); + } + } }; } ExprKind::List { elts, .. } => { @@ -146,11 +248,60 @@ fn check_names(checker: &mut Checker, expr: &Expr) { if let Some(first) = elts.first() { handle_single_name(checker, expr, first); } - } else if names_type != types::ParametrizeNameType::List { - checker.add_check(Check::new( - CheckKind::ParametrizeNamesWrongType(names_type), - Range::from_located(expr), - )); + } else { + match names_type { + types::ParametrizeNameType::List => {} + types::ParametrizeNameType::Tuple => { + let mut check = Check::new( + CheckKind::ParametrizeNamesWrongType(names_type), + Range::from_located(expr), + ); + if checker.patch(check.kind.code()) { + let mut generator = SourceCodeGenerator::new( + checker.style.indentation(), + checker.style.quote(), + checker.style.line_ending(), + ); + generator.unparse_expr( + &create_expr(ExprKind::Tuple { + elts: elts.clone(), + ctx: ExprContext::Load, + }), + 1, // so tuple is generated with parentheses + ); + match generator.generate() { + Ok(content) => { + check.amend(Fix::replacement( + content, + expr.location, + expr.end_location.unwrap(), + )); + } + Err(e) => error!( + "Failed to fix wrong name(s) type in \ + `@pytest.mark.parametrize`: {e}" + ), + }; + } + checker.add_check(check); + } + types::ParametrizeNameType::CSV => { + let mut check = Check::new( + CheckKind::ParametrizeNamesWrongType(names_type), + Range::from_located(expr), + ); + if checker.patch(check.kind.code()) { + if let Some(content) = elts_to_csv(elts, checker) { + check.amend(Fix::replacement( + content, + expr.location, + expr.end_location.unwrap(), + )); + } + } + checker.add_check(check); + } + } }; } _ => {} diff --git a/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_csv.snap b/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_csv.snap index 0f64c354bf02a..bee4145e5eda2 100644 --- a/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_csv.snap +++ b/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_csv.snap @@ -10,7 +10,14 @@ expression: checks end_location: row: 24 column: 45 - fix: ~ + fix: + content: "\"param1,param2\"" + location: + row: 24 + column: 25 + end_location: + row: 24 + column: 45 parent: ~ - kind: ParametrizeNamesWrongType: csv @@ -37,7 +44,14 @@ expression: checks end_location: row: 34 column: 45 - fix: ~ + fix: + content: "\"param1,param2\"" + location: + row: 34 + column: 25 + end_location: + row: 34 + column: 45 parent: ~ - kind: ParametrizeNamesWrongType: csv @@ -56,4 +70,24 @@ expression: checks row: 39 column: 35 parent: ~ +- kind: + ParametrizeNamesWrongType: csv + location: + row: 44 + column: 25 + end_location: + row: 44 + column: 50 + fix: ~ + parent: ~ +- kind: + ParametrizeNamesWrongType: csv + location: + row: 49 + column: 25 + end_location: + row: 49 + column: 46 + fix: ~ + parent: ~ diff --git a/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_default.snap b/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_default.snap index ae02f64d2b99a..d0cbb099ce3b5 100644 --- a/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_default.snap +++ b/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_default.snap @@ -78,7 +78,14 @@ expression: checks end_location: row: 34 column: 45 - fix: ~ + fix: + content: "(\"param1\", \"param2\")" + location: + row: 34 + column: 25 + end_location: + row: 34 + column: 45 parent: ~ - kind: ParametrizeNamesWrongType: csv @@ -97,4 +104,38 @@ expression: checks row: 39 column: 35 parent: ~ +- kind: + ParametrizeNamesWrongType: tuple + location: + row: 44 + column: 25 + end_location: + row: 44 + column: 50 + fix: + content: "(some_expr, another_expr)" + location: + row: 44 + column: 25 + end_location: + row: 44 + column: 50 + parent: ~ +- kind: + ParametrizeNamesWrongType: tuple + location: + row: 49 + column: 25 + end_location: + row: 49 + column: 46 + fix: + content: "(some_expr, \"param2\")" + location: + row: 49 + column: 25 + end_location: + row: 49 + column: 46 + parent: ~ diff --git a/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_list.snap b/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_list.snap index 66a007a2a8042..ac065f953dbca 100644 --- a/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_list.snap +++ b/src/flake8_pytest_style/snapshots/ruff__flake8_pytest_style__tests__PT006_list.snap @@ -61,7 +61,14 @@ expression: checks end_location: row: 24 column: 45 - fix: ~ + fix: + content: "[\"param1\", \"param2\"]" + location: + row: 24 + column: 25 + end_location: + row: 24 + column: 45 parent: ~ - kind: ParametrizeNamesWrongType: csv