Skip to content

Commit

Permalink
Merge pull request #502 from python-rope/lieryan-autoimport-database-…
Browse files Browse the repository at this point in the history
…wrapper

Add autoimport database wrapper
  • Loading branch information
lieryan committed Jul 29, 2022
2 parents 604ca2b + c6d7564 commit 654e062
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 37 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

- #496,#497 Add MatMul operator to patchedast
- #495 Fix autoimport collection for compiled modules
- #501 Autoimport improvements
- #501, #502 Autoimport improvements

# Release 1.2.0

Expand Down
92 changes: 92 additions & 0 deletions rope/contrib/autoimport/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import List


class FinalQuery:
def __init__(self, query):
self._query = query


class Query:
def __init__(self, query: str, columns: List[str]):
self.query = query
self.columns = columns

def select(self, *columns: str):
if not (set(columns) <= set(self.columns)):
raise ValueError(
f"Unknown column names passed: {set(columns) - set(self.columns)}"
)

selected_columns = ", ".join(columns)
return FinalQuery(f"SELECT {selected_columns} FROM {self.query}")

def select_star(self):
return FinalQuery(f"SELECT * FROM {self.query}")

def where(self, where_clause: str):
return Query(
f"{self.query} WHERE {where_clause}",
columns=self.columns,
)

def insert_into(self) -> FinalQuery:
columns = ", ".join(self.columns)
placeholders = ", ".join(["?"] * len(self.columns))
return FinalQuery(
f"INSERT INTO {self.query}({columns}) VALUES ({placeholders})"
)

def drop_table(self) -> FinalQuery:
return FinalQuery(f"DROP TABLE {self.query}")

def delete_from(self) -> FinalQuery:
return FinalQuery(f"DELETE FROM {self.query}")


class Name:
table_name = "names"
columns = [
"name",
"module",
"package",
"source",
"type",
]

@classmethod
def create_table(self, connection):
names_table = (
"(name TEXT, module TEXT, package TEXT, source INTEGER, type INTEGER)"
)
connection.execute(f"CREATE TABLE IF NOT EXISTS names{names_table}")
connection.execute("CREATE INDEX IF NOT EXISTS name ON names(name)")
connection.execute("CREATE INDEX IF NOT EXISTS module ON names(module)")
connection.execute("CREATE INDEX IF NOT EXISTS package ON names(package)")

objects = Query(table_name, columns)

search_submodule_like = objects.where('module LIKE ("%." || ?)')
search_module_like = objects.where("module LIKE (?)")

import_assist = objects.where("name LIKE (? || '%')")

search_by_name_like = objects.where("name LIKE (?)")

delete_by_module_name = objects.where("module = ?").delete_from()


class Package:
table_name = "packages"
columns = [
"package",
"path",
]

@classmethod
def create_table(self, connection):
packages_table = "(package TEXT, path TEXT)"
connection.execute(f"CREATE TABLE IF NOT EXISTS packages{packages_table}")

objects = Query(table_name, columns)

delete_by_package_name = objects.where("package = ?").delete_from()
72 changes: 36 additions & 36 deletions rope/contrib/autoimport/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
sort_and_deduplicate_tuple,
)
from rope.refactor import importutils
from rope.contrib.autoimport import models


def get_future_names(
Expand Down Expand Up @@ -107,15 +108,8 @@ def __init__(self, project: Project, observe=True, underlined=False, memory=True
project.add_observer(observer)

def _setup_db(self):
packages_table = "(package TEXT, path TEXT)"
names_table = (
"(name TEXT, module TEXT, package TEXT, source INTEGER, type INTEGER)"
)
self.connection.execute(f"CREATE TABLE IF NOT EXISTS names{names_table}")
self.connection.execute(f"CREATE TABLE IF NOT EXISTS packages{packages_table}")
self.connection.execute("CREATE INDEX IF NOT EXISTS name ON names(name)")
self.connection.execute("CREATE INDEX IF NOT EXISTS module ON names(module)")
self.connection.execute("CREATE INDEX IF NOT EXISTS package ON names(package)")
models.Name.create_table(self.connection)
models.Package.create_table(self.connection)
self.connection.commit()

def import_assist(self, starting: str):
Expand All @@ -132,9 +126,8 @@ def import_assist(self, starting: str):
__________
Return a list of ``(name, module)`` tuples
"""
results = self.connection.execute(
"SELECT name, module, source FROM names WHERE name LIKE (?)",
(starting + "%",),
results = self._execute(
models.Name.import_assist.select("name", "module", "source"), (starting,)
).fetchall()
return sort_and_deduplicate_tuple(
results
Expand Down Expand Up @@ -195,8 +188,9 @@ def _search_name(
"""
if not exact_match:
name = name + "%" # Makes the query a starts_with query
for import_name, module, source, name_type in self.connection.execute(
"SELECT name, module, source, type FROM names WHERE name LIKE (?)", (name,)
for import_name, module, source, name_type in self._execute(
models.Name.search_by_name_like.select("name", "module", "source", "type"),
(name,),
):
yield (
SearchResult(
Expand All @@ -217,9 +211,8 @@ def _search_module(
"""
if not exact_match:
name = name + "%" # Makes the query a starts_with query
for module, source in self.connection.execute(
"SELECT module, source FROM names WHERE module LIKE (?)",
("%." + name,),
for module, source in self._execute(
models.Name.search_submodule_like.select("module", "source"), (name,)
):
parts = module.split(".")
import_name = parts[-1]
Expand All @@ -235,8 +228,8 @@ def _search_module(
NameType.Module.value,
)
)
for module, source in self.connection.execute(
"SELECT module, source FROM names WHERE module LIKE (?)", (name,)
for module, source in self._execute(
models.Name.search_module_like.select("module", "source"), (name,)
):
if "." in module:
continue
Expand All @@ -246,20 +239,20 @@ def _search_module(

def get_modules(self, name) -> List[str]:
"""Get the list of modules that have global `name`."""
results = self.connection.execute(
"SELECT module, source FROM names WHERE name LIKE (?)", (name,)
results = self._execute(
models.Name.search_by_name_like.select("module", "source"), (name,)
).fetchall()
return sort_and_deduplicate(results)

def get_all_names(self) -> List[str]:
"""Get the list of all cached global names."""
results = self.connection.execute("SELECT name FROM names").fetchall()
results = self._execute(models.Name.objects.select("name")).fetchall()
return results

def _dump_all(self) -> Tuple[List[Name], List[Package]]:
"""Dump the entire database."""
name_results = self.connection.execute("SELECT * FROM names").fetchall()
package_results = self.connection.execute("SELECT * FROM packages").fetchall()
name_results = self._execute(models.Name.objects.select_star()).fetchall()
package_results = self._execute(models.Package.objects.select_star()).fetchall()
return name_results, package_results

def generate_cache(
Expand All @@ -279,8 +272,8 @@ def generate_cache(
job_set = task_handle.create_jobset(
"Generating autoimport cache", len(resources)
)
self.connection.execute(
"delete from names where package = ?", (self.project_package.name,)
self._execute(
models.Package.delete_by_package_name, (self.project_package.name,)
)
futures = []
with ProcessPoolExecutor() as executor:
Expand Down Expand Up @@ -358,8 +351,8 @@ def close(self):
def get_name_locations(self, name):
"""Return a list of ``(resource, lineno)`` tuples."""
result = []
modules = self.connection.execute(
"SELECT module FROM names WHERE name LIKE (?)", (name,)
modules = self._execute(
models.Name.search_by_name_like.select("module"), (name,)
).fetchall()
for module in modules:
try:
Expand All @@ -385,7 +378,8 @@ def clear_cache(self):
regenerating global names.
"""
self.connection.execute("drop table names")
self._execute(models.Name.objects.drop_table())
self._execute(models.Package.objects.drop_table())
self._setup_db()
self.connection.commit()

Expand Down Expand Up @@ -430,7 +424,7 @@ def _moved(self, resource: Resource, newresource: Resource):
self.update_resource(newresource)

def _del_if_exist(self, module_name, commit: bool = True):
self.connection.execute("delete from names where module = ?", (module_name,))
self._execute(models.Name.delete_by_module_name, (module_name,))
if commit:
self.connection.commit()

Expand Down Expand Up @@ -458,13 +452,11 @@ def _get_available_packages(self) -> List[Package]:

def _add_packages(self, packages: List[Package]):
data = [(p.name, str(p.path)) for p in packages]
self.connection.executemany(
"INSERT INTO packages(package, path) VALUES (?, ?)", data
)
self._executemany(models.Package.objects.insert_into(), data)

def _get_packages_from_cache(self) -> List[str]:
existing: List[str] = list(
chain(*self.connection.execute("SELECT * FROM packages").fetchall())
chain(*self._execute(models.Package.objects.select_star()).fetchall())
)
existing.append(self.project_package.name)
return existing
Expand All @@ -482,8 +474,8 @@ def _add_names(self, names: Iterable[Name]):
self._add_name(name)

def _add_name(self, name: Name):
self.connection.execute(
"INSERT INTO names(name, module, package, source, type) VALUES (?, ?, ?, ?, ?)",
self._execute(
models.Name.objects.insert_into(),
(
name.name,
name.modname,
Expand Down Expand Up @@ -525,3 +517,11 @@ def _resource_to_module(
underlined,
resource_path.name == "__init__.py",
)

def _execute(self, query: models.FinalQuery, *args, **kwargs):
assert isinstance(query, models.FinalQuery)
return self.connection.execute(query._query, *args, **kwargs)

def _executemany(self, query: models.FinalQuery, *args, **kwargs):
assert isinstance(query, models.FinalQuery)
return self.connection.executemany(query._query, *args, **kwargs)
69 changes: 69 additions & 0 deletions ropetest/contrib/autoimport/modeltest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from unittest import TestCase

from rope.contrib.autoimport import models


class QueryTest(TestCase):
def test_select_non_existent_column(self):
with self.assertRaisesRegex(ValueError, """Unknown column names passed: {['"]doesnotexist['"]}"""):
models.Name.objects.select('doesnotexist')._query


class NameModelTest(TestCase):
def test_name_objects(self):
self.assertEqual(
models.Name.objects.select_star()._query,
"SELECT * FROM names",
)

def test_query_strings(self):
with self.subTest("objects"):
self.assertEqual(
models.Name.objects.select_star()._query,
'SELECT * FROM names',
)

with self.subTest("search_submodule_like"):
self.assertEqual(
models.Name.search_submodule_like.select_star()._query,
'SELECT * FROM names WHERE module LIKE ("%." || ?)',
)

with self.subTest("search_module_like"):
self.assertEqual(
models.Name.search_module_like.select_star()._query,
'SELECT * FROM names WHERE module LIKE (?)',
)

with self.subTest("import_assist"):
self.assertEqual(
models.Name.import_assist.select_star()._query,
"SELECT * FROM names WHERE name LIKE (? || '%')",
)

with self.subTest("search_by_name_like"):
self.assertEqual(
models.Name.search_by_name_like.select_star()._query,
'SELECT * FROM names WHERE name LIKE (?)',
)

with self.subTest("delete_by_module_name"):
self.assertEqual(
models.Name.delete_by_module_name._query,
'DELETE FROM names WHERE module = ?',
)


class PackageModelTest(TestCase):
def test_query_strings(self):
with self.subTest("objects"):
self.assertEqual(
models.Package.objects.select_star()._query,
'SELECT * FROM packages',
)

with self.subTest("delete_by_package_name"):
self.assertEqual(
models.Package.delete_by_package_name._query,
'DELETE FROM packages WHERE package = ?',
)

0 comments on commit 654e062

Please sign in to comment.