Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] record dependencies when resolving across imports #11176

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions crates/red_knot/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ pub trait SemanticDb: SourceDb {

fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable>;

fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type;

// mutations

fn add_module(&mut self, path: &Path) -> Option<(Module, Vec<Arc<ModuleData>>)>;

fn set_module_search_paths(&mut self, paths: Vec<ModuleSearchPath>);

fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type;
}

pub trait Db: SemanticDb {}
Expand Down Expand Up @@ -159,6 +159,10 @@ pub(crate) mod tests {
path_to_module(self, path)
}

fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}

fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable> {
symbol_table(self, file_id)
}
Expand All @@ -170,9 +174,5 @@ pub(crate) mod tests {
fn set_module_search_paths(&mut self, paths: Vec<ModuleSearchPath>) {
set_module_search_paths(self, paths);
}

fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}
}
}
8 changes: 4 additions & 4 deletions crates/red_knot/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ impl SemanticDb for Program {
symbol_table(self, file_id)
}

fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}

// Mutations

fn add_module(&mut self, path: &Path) -> Option<(Module, Vec<Arc<ModuleData>>)> {
Expand All @@ -120,10 +124,6 @@ impl SemanticDb for Program {
fn set_module_search_paths(&mut self, paths: Vec<ModuleSearchPath>) {
set_module_search_paths(self, paths);
}

fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}
}

impl Db for Program {}
Expand Down
53 changes: 40 additions & 13 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![allow(dead_code)]
use crate::ast_ids::NodeKey;
use crate::files::FileId;
use crate::module::ModuleName;
use crate::symbols::SymbolId;
use crate::{FxDashMap, FxIndexSet, Name};
use crate::{FxDashMap, FxHashSet, FxIndexSet, Name};
use ruff_index::{newtype_index, IndexVec};
use rustc_hash::FxHashMap;

Expand Down Expand Up @@ -49,17 +50,17 @@ pub struct TypeStore {
}

impl TypeStore {
pub fn remove_module(&mut self, file_id: FileId) {
pub fn remove_module(&self, file_id: FileId) {
self.modules.remove(&file_id);
}

pub fn cache_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId, ty: Type) {
pub fn cache_symbol_type(&self, file_id: FileId, symbol_id: SymbolId, ty: Type) {
self.add_or_get_module(file_id)
.symbol_types
.insert(symbol_id, ty);
}

pub fn cache_node_type(&mut self, file_id: FileId, node_key: NodeKey, ty: Type) {
pub fn cache_node_type(&self, file_id: FileId, node_key: NodeKey, ty: Type) {
self.add_or_get_module(file_id)
.node_types
.insert(node_key, ty);
Expand All @@ -79,7 +80,7 @@ impl TypeStore {
.copied()
}

fn add_or_get_module(&mut self, file_id: FileId) -> ModuleStoreRefMut {
fn add_or_get_module(&self, file_id: FileId) -> ModuleStoreRefMut {
self.modules
.entry(file_id)
.or_insert_with(|| ModuleTypeStore::new(file_id))
Expand All @@ -93,20 +94,20 @@ impl TypeStore {
self.modules.get(&file_id)
}

fn add_function(&mut self, file_id: FileId, name: &str) -> FunctionTypeId {
fn add_function(&self, file_id: FileId, name: &str) -> FunctionTypeId {
self.add_or_get_module(file_id).add_function(name)
}

fn add_class(&mut self, file_id: FileId, name: &str) -> ClassTypeId {
fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId {
self.add_or_get_module(file_id).add_class(name)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
self.add_or_get_module(file_id).add_union(elems)
}

fn add_intersection(
&mut self,
&self,
file_id: FileId,
positive: &[Type],
negative: &[Type],
Expand Down Expand Up @@ -142,6 +143,24 @@ impl TypeStore {
intersection_id: id.intersection_id,
}
}

/// Records that the inferred type of the symbol identified by `from` depends on the symbol
/// identified by `to`.
///
/// The dependency is stored in the type store of the module that owns `from`; that module
/// store is created first if it does not exist yet.
fn record_symbol_dependency(&self, from: (FileId, SymbolId), to: (FileId, SymbolId)) {
    let (dependent_file, dependent_symbol) = from;
    // Keep the module entry in a named binding so the map update below reads as one step.
    let mut module_store = self.add_or_get_module(dependent_file);
    module_store
        .symbol_dependencies
        .entry(dependent_symbol)
        .or_default()
        .insert(to);
}

fn record_module_dependency(&self, from: (FileId, SymbolId), to: ModuleName) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this track the ModuleId instead of the ModuleName? For example, what if the module name remains unchanged, but it now resolves to a different module:

  • Before: namespace1/foo/bar.py and namespace2/foo/baz.py where foo has no __init__.py and the ModuleName is foo.baz
  • After: Add a __init__.py to namespace1/foo, now foo.baz no longer resolves.

Even without this limitation, storing ModuleIds would be cheaper than storing ModuleNames. However, ModuleIds have the downside that they're less stable. So storing ModuleNames might, after all, be desired.

Copy link
Contributor Author

@carljm carljm Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case that you outline (same module name referring to a different file) is precisely why I think this needs to store a ModuleName. The dependency here is on "whatever this module name resolves to" -- if we have a dependency on foo.bar, and foo/bar.py doesn't change, but foo/bar/__init__.py is added so foo.bar now resolves there, we must invalidate the dependency on foo.bar. So the dependency is on the module name, not on the file ID.

Whether it could be on the module ID instead (i.e. type Module), or not, is unclear to me. That depends whether we guarantee stable 1:1 mapping between FileId and Module, or between ModuleName and Module (or neither?). If it's the latter (stable 1:1 mapping between ModuleName and Module), then I could store a dependency to Module here and it would be a little cheaper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mapping between ModuleName and ModuleId is stable, although it is not guaranteed to be stable across runs (that's not your concern, though; it's something that the persistent caching layer must patch up).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, maybe it would work out to keep this dependency to FileId, too? As long as in the case where a file is shadowed by a different file for the same module name, we always remove the now-shadowed file, invalidating all its data?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might. However, I'm also happy to use ModuleId when we know that it always comes from a Module. It gives the id more semantic meaning.

let (from_file_id, from_symbol_id) = from;
self.add_or_get_module(from_file_id)
.module_dependencies
.entry(from_symbol_id)
.or_default()
.insert(to);
}
}

type ModuleStoreRef<'a> = dashmap::mapref::one::Ref<
Expand Down Expand Up @@ -265,6 +284,12 @@ struct ModuleTypeStore {
symbol_types: FxHashMap<SymbolId, Type>,
/// cached types of AST nodes in this module
node_types: FxHashMap<NodeKey, Type>,
// the inferred type for symbol K depends on the type of symbols in V
carljm marked this conversation as resolved.
Show resolved Hide resolved
symbol_dependencies: FxHashMap<SymbolId, FxHashSet<(FileId, SymbolId)>>,
// the inferred type for symbol K depends on the modules in V; this type of dependency is
// recorded when e.g. the target symbol doesn't exist in the module, so we can't record a
// dependency on a symbol, but if the module changes it could still change our resolution
carljm marked this conversation as resolved.
Show resolved Hide resolved
module_dependencies: FxHashMap<SymbolId, FxHashSet<ModuleName>>,
}

impl ModuleTypeStore {
Expand All @@ -277,6 +302,8 @@ impl ModuleTypeStore {
intersections: IndexVec::default(),
symbol_types: FxHashMap::default(),
node_types: FxHashMap::default(),
symbol_dependencies: FxHashMap::default(),
module_dependencies: FxHashMap::default(),
}
}

Expand Down Expand Up @@ -462,7 +489,7 @@ mod tests {

#[test]
fn add_class() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_class(file_id, "C");
Expand All @@ -473,7 +500,7 @@ mod tests {

#[test]
fn add_function() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_function(file_id, "func");
Expand All @@ -484,7 +511,7 @@ mod tests {

#[test]
fn add_union() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1");
carljm marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -501,7 +528,7 @@ mod tests {

#[test]
fn add_intersection() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1");
Expand Down
39 changes: 31 additions & 8 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ use crate::types::Type;
use crate::FileId;
use ruff_python_ast::AstNode;

// TODO this should not take a &mut db, it should be a query, not a mutation. This means we'll need
// to use interior mutability in TypeStore instead, and avoid races in populating the cache.
#[tracing::instrument(level = "trace", skip(db))]
MichaReiser marked this conversation as resolved.
Show resolved Hide resolved
pub fn infer_symbol_type<Db>(db: &mut Db, file_id: FileId, symbol_id: SymbolId) -> Type
pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add some documentation explaining what the file_id and symbol_id parameters are?

where
Db: SemanticDb + HasJar<SemanticJar>,
{
Expand All @@ -36,15 +34,27 @@ where
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(module) = db.resolve_module(module_name) {
if let Some(module) = db.resolve_module(module_name.clone()) {
let remote_file_id = module.path(db).file();
let remote_symbols = db.symbol_table(remote_file_id);
if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) {
// TODO integrate this into module and symbol-resolution APIs (requiring a
// "requester" argument) so that it doesn't have to be remembered
db.jar().type_store.record_symbol_dependency(
(file_id, symbol_id),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I already commented this on the other PR. I think it would be nice if we had a struct that combines file_id and symbol_id. I find that clearer than unnamed two-element tuples. We may then even change the infer_symbol_type method to take such an instance as the argument instead of a file id and symbol id pair.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, agreed, will push that new type in this PR.

(remote_file_id, remote_symbol_id),
);
db.infer_symbol_type(remote_file_id, remote_symbol_id)
} else {
db.jar()
.type_store
.record_module_dependency((file_id, symbol_id), module_name);
Type::Unknown
}
} else {
db.jar()
.type_store
.record_module_dependency((file_id, symbol_id), module_name);
Type::Unknown
}
}
Expand All @@ -60,7 +70,7 @@ where
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());

let store = &mut db.jar_mut().type_store;
let store = &db.jar().type_store;
let ty = Type::Class(store.add_class(file_id, &node.name.id));
store.cache_node_type(file_id, *node_key.erased(), ty);
ty
Expand All @@ -69,10 +79,9 @@ where
_ => todo!("other kinds of definitions"),
};

db.jar_mut()
db.jar()
.type_store
.cache_symbol_type(file_id, symbol_id, ty);
// TODO record dependencies
ty
}

Expand Down Expand Up @@ -112,7 +121,7 @@ mod tests {
fn follow_import_to_class() -> std::io::Result<()> {
let TestCase {
src,
mut db,
db,
temp_dir: _temp_dir,
} = create_test()?;

Expand All @@ -132,10 +141,24 @@ mod tests {

let ty = db.infer_symbol_type(a_file, d_sym);

let b_file = db
.resolve_module(ModuleName::new("b"))
.expect("module should be found")
.path(&db)
.file();
let b_syms = db.symbol_table(b_file);
let c_sym = b_syms
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let jar = HasJar::<SemanticJar>::jar(&db);

assert!(matches!(ty, Type::Class(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
assert_eq!(
jar.type_store.get_module(a_file).symbol_dependencies[&d_sym],
[(b_file, c_sym)].iter().copied().collect()
);
Ok(())
}
}