Skip to content

Commit

Permalink
add dependency tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Apr 27, 2024
1 parent 2d6978f commit a2ec705
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 29 deletions.
12 changes: 6 additions & 6 deletions crates/red_knot/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ pub trait SemanticDb: SourceDb {

fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable>;

fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type;

// mutations
fn path_to_module(&mut self, path: &Path) -> Option<Module>;

fn add_module(&mut self, path: &Path) -> Option<(Module, Vec<Arc<ModuleData>>)>;

fn set_module_search_paths(&mut self, paths: Vec<ModuleSearchPath>);

fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type;
}

pub trait Db: SemanticDb {}
Expand Down Expand Up @@ -152,6 +152,10 @@ pub(crate) mod tests {
symbol_table(self, file_id)
}

fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}

fn path_to_module(&mut self, path: &Path) -> Option<Module> {
path_to_module(self, path)
}
Expand All @@ -163,9 +167,5 @@ pub(crate) mod tests {
fn set_module_search_paths(&mut self, paths: Vec<ModuleSearchPath>) {
set_module_search_paths(self, paths);
}

fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}
}
}
8 changes: 4 additions & 4 deletions crates/red_knot/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ impl SemanticDb for Program {
symbol_table(self, file_id)
}

fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}

// Mutations
fn path_to_module(&mut self, path: &Path) -> Option<Module> {
path_to_module(self, path)
Expand All @@ -115,10 +119,6 @@ impl SemanticDb for Program {
fn set_module_search_paths(&mut self, paths: Vec<ModuleSearchPath>) {
set_module_search_paths(self, paths);
}

fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type {
infer_symbol_type(self, file_id, symbol_id)
}
}

impl Db for Program {}
Expand Down
53 changes: 40 additions & 13 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![allow(dead_code)]
use crate::ast_ids::NodeKey;
use crate::files::FileId;
use crate::module::ModuleName;
use crate::symbols::SymbolId;
use crate::{FxDashMap, FxIndexSet, Name};
use crate::{FxDashMap, FxHashSet, FxIndexSet, Name};
use ruff_index::{newtype_index, IndexVec};
use rustc_hash::FxHashMap;

Expand Down Expand Up @@ -49,17 +50,17 @@ pub struct TypeStore {
}

impl TypeStore {
pub fn remove_module(&mut self, file_id: FileId) {
pub fn remove_module(&self, file_id: FileId) {
self.modules.remove(&file_id);
}

pub fn cache_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId, ty: Type) {
pub fn cache_symbol_type(&self, file_id: FileId, symbol_id: SymbolId, ty: Type) {
self.add_or_get_module(file_id)
.symbol_types
.insert(symbol_id, ty);
}

pub fn cache_node_type(&mut self, file_id: FileId, node_key: NodeKey, ty: Type) {
pub fn cache_node_type(&self, file_id: FileId, node_key: NodeKey, ty: Type) {
self.add_or_get_module(file_id)
.node_types
.insert(node_key, ty);
Expand All @@ -79,7 +80,7 @@ impl TypeStore {
.copied()
}

fn add_or_get_module(&mut self, file_id: FileId) -> ModuleStoreRefMut {
fn add_or_get_module(&self, file_id: FileId) -> ModuleStoreRefMut {
self.modules
.entry(file_id)
.or_insert_with(|| ModuleTypeStore::new(file_id))
Expand All @@ -93,20 +94,20 @@ impl TypeStore {
self.modules.get(&file_id)
}

fn add_function(&mut self, file_id: FileId, name: &str) -> FunctionTypeId {
fn add_function(&self, file_id: FileId, name: &str) -> FunctionTypeId {
self.add_or_get_module(file_id).add_function(name)
}

fn add_class(&mut self, file_id: FileId, name: &str) -> ClassTypeId {
fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId {
self.add_or_get_module(file_id).add_class(name)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
self.add_or_get_module(file_id).add_union(elems)
}

fn add_intersection(
&mut self,
&self,
file_id: FileId,
positive: &[Type],
negative: &[Type],
Expand Down Expand Up @@ -142,6 +143,24 @@ impl TypeStore {
intersection_id: id.intersection_id,
}
}

fn record_symbol_dependency(&self, from: (FileId, SymbolId), to: (FileId, SymbolId)) {
let (from_file_id, from_symbol_id) = from;
self.add_or_get_module(from_file_id)
.symbol_dependencies
.entry(from_symbol_id)
.or_default()
.insert(to);
}

fn record_module_dependency(&self, from: (FileId, SymbolId), to: ModuleName) {
let (from_file_id, from_symbol_id) = from;
self.add_or_get_module(from_file_id)
.module_dependencies
.entry(from_symbol_id)
.or_default()
.insert(to);
}
}

type ModuleStoreRef<'a> = dashmap::mapref::one::Ref<
Expand Down Expand Up @@ -265,6 +284,12 @@ struct ModuleTypeStore {
symbol_types: FxHashMap<SymbolId, Type>,
/// cached types of AST nodes in this module
node_types: FxHashMap<NodeKey, Type>,
// the inferred type for symbol K depends on the type of symbols in V
symbol_dependencies: FxHashMap<SymbolId, FxHashSet<(FileId, SymbolId)>>,
// the inferred type for symbol K depends on the modules in V; this type of dependency is
// recorded when e.g. the target symbol doesn't exist in the module, so we can't record a
// dependency on a symbol, but if the module changes it could still change our resolution)
module_dependencies: FxHashMap<SymbolId, FxHashSet<ModuleName>>,
}

impl ModuleTypeStore {
Expand All @@ -277,6 +302,8 @@ impl ModuleTypeStore {
intersections: IndexVec::default(),
symbol_types: FxHashMap::default(),
node_types: FxHashMap::default(),
symbol_dependencies: FxHashMap::default(),
module_dependencies: FxHashMap::default(),
}
}

Expand Down Expand Up @@ -462,7 +489,7 @@ mod tests {

#[test]
fn add_class() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_class(file_id, "C");
Expand All @@ -473,7 +500,7 @@ mod tests {

#[test]
fn add_function() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_function(file_id, "func");
Expand All @@ -484,7 +511,7 @@ mod tests {

#[test]
fn add_union() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1");
Expand All @@ -501,7 +528,7 @@ mod tests {

#[test]
fn add_intersection() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1");
Expand Down
37 changes: 31 additions & 6 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ruff_python_ast::AstNode;
// TODO this should not take a &mut db, it should be a query, not a mutation. This means we'll need
// to use interior mutability in TypeStore instead, and avoid races in populating the cache.
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type<Db>(db: &mut Db, file_id: FileId, symbol_id: SymbolId) -> Type
pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type
where
Db: SemanticDb + HasJar<SemanticJar>,
{
Expand All @@ -36,15 +36,27 @@ where
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(module) = db.resolve_module(module_name) {
if let Some(module) = db.resolve_module(module_name.clone()) {
let remote_file_id = module.path(db).file();
let remote_symbols = db.symbol_table(remote_file_id);
if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) {
// TODO integrate this into module and symbol-resolution APIs (requiring a
// "requester" argument) so that it doesn't have to be remembered
db.jar().type_store.record_symbol_dependency(
(file_id, symbol_id),
(remote_file_id, remote_symbol_id),
);
db.infer_symbol_type(remote_file_id, remote_symbol_id)
} else {
db.jar()
.type_store
.record_module_dependency((file_id, symbol_id), module_name);
Type::Unknown
}
} else {
db.jar()
.type_store
.record_module_dependency((file_id, symbol_id), module_name);
Type::Unknown
}
}
Expand All @@ -60,7 +72,7 @@ where
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());

let store = &mut db.jar_mut().type_store;
let store = &db.jar().type_store;
let ty = Type::Class(store.add_class(file_id, &node.name.id));
store.cache_node_type(file_id, *node_key.erased(), ty);
ty
Expand All @@ -69,10 +81,9 @@ where
_ => todo!("other kinds of definitions"),
};

db.jar_mut()
db.jar()
.type_store
.cache_symbol_type(file_id, symbol_id, ty);
// TODO record dependencies
ty
}

Expand Down Expand Up @@ -112,7 +123,7 @@ mod tests {
fn follow_import_to_class() -> std::io::Result<()> {
let TestCase {
src,
mut db,
db,
temp_dir: _temp_dir,
} = create_test()?;

Expand All @@ -132,10 +143,24 @@ mod tests {

let ty = db.infer_symbol_type(a_file, d_sym);

let b_file = db
.resolve_module(ModuleName::new("b"))
.expect("module should be found")
.path(&db)
.file();
let b_syms = db.symbol_table(b_file);
let c_sym = b_syms
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let jar = HasJar::<SemanticJar>::jar(&db);

assert!(matches!(ty, Type::Class(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
assert_eq!(
jar.type_store.get_module(a_file).symbol_dependencies[&d_sym],
[(b_file, c_sym)].iter().copied().collect()
);
Ok(())
}
}

0 comments on commit a2ec705

Please sign in to comment.