From 6012c5d8ab13d18dca4811c20be8b745fd4cf77a Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Wed, 25 Oct 2023 14:23:33 +0530 Subject: [PATCH] Use source type to determine parser mode for formatting --- crates/ruff_python_formatter/src/cli.rs | 12 +++++++----- crates/ruff_python_formatter/src/comments/mod.rs | 9 +++++---- crates/ruff_python_formatter/src/lib.rs | 15 +++++++++------ crates/ruff_python_index/src/comment_ranges.rs | 6 ++++-- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/crates/ruff_python_formatter/src/cli.rs b/crates/ruff_python_formatter/src/cli.rs index e13fb85fe6831..09140af80689e 100644 --- a/crates/ruff_python_formatter/src/cli.rs +++ b/crates/ruff_python_formatter/src/cli.rs @@ -6,8 +6,9 @@ use anyhow::{format_err, Context, Result}; use clap::{command, Parser, ValueEnum}; use ruff_formatter::SourceCode; +use ruff_python_ast::PySourceType; use ruff_python_index::tokens_and_ranges; -use ruff_python_parser::{parse_ok_tokens, Mode}; +use ruff_python_parser::{parse_ok_tokens, AsMode}; use ruff_text_size::Ranged; use crate::comments::collect_comments; @@ -38,15 +39,16 @@ pub struct Cli { pub print_comments: bool, } -pub fn format_and_debug_print(source: &str, cli: &Cli, source_type: &Path) -> Result { - let (tokens, comment_ranges) = tokens_and_ranges(source) +pub fn format_and_debug_print(source: &str, cli: &Cli, source_path: &Path) -> Result { + let source_type = PySourceType::from(source_path); + let (tokens, comment_ranges) = tokens_and_ranges(source, source_type) .map_err(|err| format_err!("Source contains syntax errors {err:?}"))?; // Parse the AST. - let module = parse_ok_tokens(tokens, source, Mode::Module, "") + let module = parse_ok_tokens(tokens, source, source_type.as_mode(), "") .context("Syntax error in input")?; - let options = PyFormatOptions::from_extension(source_type); + let options = PyFormatOptions::from_extension(source_path); let source_code = SourceCode::new(source); let formatted = format_module_ast(&module, &comment_ranges, source, options) diff --git a/crates/ruff_python_formatter/src/comments/mod.rs b/crates/ruff_python_formatter/src/comments/mod.rs index cf51a37d43578..48733c652a90f 100644 --- a/crates/ruff_python_formatter/src/comments/mod.rs +++ b/crates/ruff_python_formatter/src/comments/mod.rs @@ -548,10 +548,10 @@ mod tests { use insta::assert_debug_snapshot; use ruff_formatter::SourceCode; - use ruff_python_ast::Mod; + use ruff_python_ast::{Mod, PySourceType}; use ruff_python_index::tokens_and_ranges; - use ruff_python_parser::{parse_ok_tokens, Mode}; + use ruff_python_parser::{parse_ok_tokens, AsMode}; use ruff_python_trivia::CommentRanges; use crate::comments::Comments; @@ -565,9 +565,10 @@ mod tests { impl<'a> CommentsTestCase<'a> { fn from_code(source: &'a str) -> Self { let source_code = SourceCode::new(source); + let source_type = PySourceType::Python; let (tokens, comment_ranges) = - tokens_and_ranges(source).expect("Expect source to be valid Python"); - let parsed = parse_ok_tokens(tokens, source, Mode::Module, "test.py") + tokens_and_ranges(source, source_type).expect("Expect source to be valid Python"); + let parsed = parse_ok_tokens(tokens, source, source_type.as_mode(), "test.py") .expect("Expect source to be valid Python"); CommentsTestCase { diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index 4271d3aa48197..7697789fe072c 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -7,7 +7,7 @@ use ruff_python_ast::AstNode; use ruff_python_ast::Mod; use ruff_python_index::tokens_and_ranges; use ruff_python_parser::lexer::LexicalError; -use ruff_python_parser::{parse_ok_tokens, Mode, ParseError}; +use ruff_python_parser::{parse_ok_tokens, AsMode, ParseError}; use ruff_python_trivia::CommentRanges; use ruff_source_file::Locator; @@ -130,8 +130,9 @@ pub fn format_module_source( source: &str, options: PyFormatOptions, ) -> Result { - let (tokens, comment_ranges) = tokens_and_ranges(source)?; - let module = parse_ok_tokens(tokens, source, Mode::Module, "")?; + let source_type = options.source_type(); + let (tokens, comment_ranges) = tokens_and_ranges(source, source_type)?; + let module = parse_ok_tokens(tokens, source, source_type.as_mode(), "")?; let formatted = format_module_ast(&module, &comment_ranges, source, options)?; Ok(formatted.print()?) } @@ -172,9 +173,10 @@ mod tests { use anyhow::Result; use insta::assert_snapshot; + use ruff_python_ast::PySourceType; use ruff_python_index::tokens_and_ranges; - use ruff_python_parser::{parse_ok_tokens, Mode}; + use ruff_python_parser::{parse_ok_tokens, AsMode}; use crate::{format_module_ast, format_module_source, PyFormatOptions}; @@ -213,11 +215,12 @@ def main() -> None: ] "#; - let (tokens, comment_ranges) = tokens_and_ranges(source).unwrap(); + let source_type = PySourceType::Python; + let (tokens, comment_ranges) = tokens_and_ranges(source, source_type).unwrap(); // Parse the AST. let source_path = "code_inline.py"; - let module = parse_ok_tokens(tokens, source, Mode::Module, source_path).unwrap(); + let module = parse_ok_tokens(tokens, source, source_type.as_mode(), source_path).unwrap(); let options = PyFormatOptions::from_extension(Path::new(source_path)); let formatted = format_module_ast(&module, &comment_ranges, source, options).unwrap(); diff --git a/crates/ruff_python_index/src/comment_ranges.rs b/crates/ruff_python_index/src/comment_ranges.rs index f1c4387b01f3c..e554bb8b00e63 100644 --- a/crates/ruff_python_index/src/comment_ranges.rs +++ b/crates/ruff_python_index/src/comment_ranges.rs @@ -1,7 +1,8 @@ use std::fmt::Debug; +use ruff_python_ast::PySourceType; use ruff_python_parser::lexer::{lex, LexicalError}; -use ruff_python_parser::{Mode, Tok}; +use ruff_python_parser::{AsMode, Tok}; use ruff_python_trivia::CommentRanges; use ruff_text_size::TextRange; @@ -25,11 +26,12 @@ impl CommentRangesBuilder { /// Helper method to lex and extract comment ranges pub fn tokens_and_ranges( source: &str, + source_type: PySourceType, ) -> Result<(Vec<(Tok, TextRange)>, CommentRanges), LexicalError> { let mut tokens = Vec::new(); let mut comment_ranges = CommentRangesBuilder::default(); - for result in lex(source, Mode::Module) { + for result in lex(source, source_type.as_mode()) { let (token, range) = result?; comment_ranges.visit_token(&token, range);