Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 88 additions & 2 deletions pyrefly/lib/state/lsp/pytest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
*/

use pyrefly_build::handle::Handle;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::module_name::ModuleNameWithKind;
use pyrefly_python::module_path::ModulePath;
use pyrefly_python::module_path::ModulePathDetails;
use pyrefly_python::symbol_kind::SymbolKind;
use ruff_python_ast::AnyNodeRef;
use ruff_python_ast::Identifier;
Expand All @@ -16,14 +20,61 @@ use vec1::Vec1;
use super::DefinitionMetadata;
use super::FindDefinitionItemWithDocstring;
use crate::state::pytest::find_pytest_fixture_definitions_for_parameter;
use crate::state::pytest::find_pytest_fixture_definitions_in_module;
use crate::state::pytest::find_pytest_fixture_parameter_references;
use crate::state::pytest::is_pytest_fixture_parameter_context;
use crate::state::state::Transaction;

pub(crate) fn pytest_conftest_handles(
transaction: &Transaction<'_>,
handle: &Handle,
) -> Vec<Handle> {
let module_path = handle.path();
let Some(mut dir) = module_path.as_path().parent() else {
return Vec::new();
};
let root = module_path
.root_of(handle.module())
.unwrap_or_else(|| dir.to_path_buf());
let is_memory = matches!(module_path.details(), ModulePathDetails::Memory(_));
let mut conftest_paths = Vec::new();
loop {
let conftest_pyi = dir.join("conftest.pyi");
let conftest_py = dir.join("conftest.py");
if is_memory {
conftest_paths.push(ModulePath::memory(conftest_pyi.clone()));
conftest_paths.push(ModulePath::memory(conftest_py.clone()));
} else {
if conftest_pyi.exists() {
conftest_paths.push(ModulePath::filesystem(conftest_pyi));
}
if conftest_py.exists() {
conftest_paths.push(ModulePath::filesystem(conftest_py));
}
}
if dir == root {
break;
}
let Some(parent) = dir.parent() else {
break;
};
dir = parent;
}
let mut handles = Vec::new();
for path in conftest_paths {
let config = transaction
.config_finder()
.python_file(ModuleNameWithKind::guaranteed(ModuleName::unknown()), &path);
handles.push(config.handle_from_module_path(path));
}
handles
}

impl<'a> Transaction<'a> {
/// Resolve a pytest fixture parameter to the fixture functions that can provide it.
///
/// This runs during definition lookup. The common non-pytest path is cheap because we first
/// ask bindings for pytest metadata, which is absent in modules that do not import pytest.
/// require either a pytest fixture function or a test-named function.
pub(super) fn pytest_fixture_definitions_for_parameter(
&self,
handle: &Handle,
Expand All @@ -32,14 +83,17 @@ impl<'a> Transaction<'a> {
) -> Option<Vec1<FindDefinitionItemWithDocstring>> {
let mod_module = self.get_ast(handle)?;
let bindings = self.get_bindings(handle)?;
if !is_pytest_fixture_parameter_context(&bindings, covering_nodes) {
return None;
}
let matches = find_pytest_fixture_definitions_for_parameter(
mod_module.as_ref(),
&bindings,
identifier,
covering_nodes,
);
let module_info = self.get_module_info(handle)?;
let definitions = matches
let mut definitions: Vec<_> = matches
.into_iter()
.map(|fixture| FindDefinitionItemWithDocstring {
metadata: DefinitionMetadata::Variable(Some(SymbolKind::Function)),
Expand All @@ -49,6 +103,38 @@ impl<'a> Transaction<'a> {
display_name: Some(fixture.name.as_str().to_owned()),
})
.collect();
if definitions.is_empty() {
for conftest_handle in pytest_conftest_handles(self, handle) {
let Some(conftest_ast) = self.get_ast(&conftest_handle) else {
continue;
};
let Some(conftest_bindings) = self.get_bindings(&conftest_handle) else {
continue;
};
let Some(conftest_module_info) = self.get_module_info(&conftest_handle) else {
continue;
};
definitions.extend(
find_pytest_fixture_definitions_in_module(
conftest_ast.as_ref(),
&conftest_bindings,
identifier.id(),
None,
)
.into_iter()
.map(|fixture| FindDefinitionItemWithDocstring {
metadata: DefinitionMetadata::Variable(Some(SymbolKind::Function)),
definition_range: fixture.range,
module: conftest_module_info.clone(),
docstring_range: fixture.docstring_range,
display_name: Some(fixture.name.as_str().to_owned()),
}),
);
if !definitions.is_empty() {
break;
}
}
}
Vec1::try_from_vec(definitions).ok()
}

Expand Down
49 changes: 2 additions & 47 deletions pyrefly/lib/state/lsp/quick_fixes/pytest_fixture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ use std::collections::HashSet;

use dupe::Dupe;
use pyrefly_build::handle::Handle;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::module_name::ModuleNameWithKind;
use pyrefly_python::module_path::ModulePath;
use pyrefly_python::module_path::ModulePathDetails;
use pyrefly_python::short_identifier::ShortIdentifier;
use pyrefly_types::display::LspDisplayMode;
use pyrefly_types::types::Type;
Expand All @@ -32,6 +28,7 @@ use crate::binding::binding::Key;
use crate::state::ide::insert_import_edit;
use crate::state::lsp::ImportFormat;
use crate::state::lsp::LocalRefactorCodeAction;
use crate::state::lsp::pytest::pytest_conftest_handles;
use crate::state::state::Transaction;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -214,48 +211,6 @@ fn fixture_types_for_module(transaction: &Transaction<'_>, handle: &Handle) -> H
fixtures
}

fn conftest_handles(transaction: &Transaction<'_>, handle: &Handle) -> Vec<Handle> {
let module_path = handle.path();
let Some(mut dir) = module_path.as_path().parent() else {
return Vec::new();
};
let root = module_path
.root_of(handle.module())
.unwrap_or_else(|| dir.to_path_buf());
let is_memory = matches!(module_path.details(), ModulePathDetails::Memory(_));
let mut conftest_paths = Vec::new();
loop {
let conftest_pyi = dir.join("conftest.pyi");
let conftest_py = dir.join("conftest.py");
if is_memory {
conftest_paths.push(ModulePath::memory(conftest_pyi.clone()));
conftest_paths.push(ModulePath::memory(conftest_py.clone()));
} else {
if conftest_pyi.exists() {
conftest_paths.push(ModulePath::filesystem(conftest_pyi));
}
if conftest_py.exists() {
conftest_paths.push(ModulePath::filesystem(conftest_py));
}
}
if dir == root {
break;
}
let Some(parent) = dir.parent() else {
break;
};
dir = parent;
}
let mut handles = Vec::new();
for path in conftest_paths {
let config = transaction
.config_finder()
.python_file(ModuleNameWithKind::guaranteed(ModuleName::unknown()), &path);
handles.push(config.handle_from_module_path(path));
}
handles
}

fn import_edits_for_type(
transaction: &Transaction<'_>,
ast: &ModModule,
Expand Down Expand Up @@ -415,7 +370,7 @@ pub(crate) fn pytest_fixture_type_annotation_code_actions(
}

let mut fixture_types = fixture_types_for_module(transaction, handle);
for conftest_handle in conftest_handles(transaction, handle) {
for conftest_handle in pytest_conftest_handles(transaction, handle) {
let conftest_types = fixture_types_for_module(transaction, &conftest_handle);
for (name, ty) in conftest_types {
fixture_types.entry(name).or_insert(ty);
Expand Down
47 changes: 36 additions & 11 deletions pyrefly/lib/state/pytest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,36 @@ pub(crate) fn find_pytest_fixture_definitions_for_parameter(
identifier: &Identifier,
covering_nodes: &[AnyNodeRef],
) -> Vec<PytestFixtureDefinition> {
if !is_pytest_fixture_parameter_context(bindings, covering_nodes) {
return Vec::new();
}
let Some(pytest_info) = bindings.pytest_info() else {
return Vec::new();
};
let class_def = covering_nodes.iter().find_map(|node| match node {
AnyNodeRef::StmtClassDef(stmt) => Some(stmt),
_ => None,
});
let class_key = class_def.and_then(|def| class_key_for_definition(bindings, def));
let Some(fixture_class_key) =
pytest_info.visible_fixture_class_key(identifier.id(), class_key.as_ref())
else {
return Vec::new();
};

find_pytest_fixture_definitions_in_module(module, bindings, identifier.id(), fixture_class_key)
}

pub(crate) fn is_pytest_fixture_parameter_context(
bindings: &Bindings,
covering_nodes: &[AnyNodeRef],
) -> bool {
let function_def = covering_nodes.iter().find_map(|node| match node {
AnyNodeRef::StmtFunctionDef(stmt) => Some(stmt),
_ => None,
});
let Some(function_def) = function_def else {
return Vec::new();
return false;
};
let class_def = covering_nodes.iter().find_map(|node| match node {
AnyNodeRef::StmtClassDef(stmt) => Some(stmt),
Expand All @@ -240,23 +261,27 @@ pub(crate) fn find_pytest_fixture_definitions_for_parameter(
let class_is_test = class_def.map(|def| is_pytest_test_class(def));
let class_key = class_def.and_then(|def| class_key_for_definition(bindings, def));

if !is_pytest_fixture_function(function_def, class_key, pytest_info)
&& !is_pytest_test_function(function_def, class_is_test)
{
return Vec::new();
}
let Some(fixture_class_key) =
pytest_info.visible_fixture_class_key(identifier.id(), class_key.as_ref())
else {
bindings
.pytest_info()
.is_some_and(|pytest_info| is_pytest_fixture_function(function_def, class_key, pytest_info))
|| is_pytest_test_function(function_def, class_is_test)
}

pub(crate) fn find_pytest_fixture_definitions_in_module(
module: &ModModule,
bindings: &Bindings,
fixture_name: &Name,
fixture_class_key: Option<Idx<KeyClass>>,
) -> Vec<PytestFixtureDefinition> {
let Some(pytest_info) = bindings.pytest_info() else {
return Vec::new();
};

let mut matches = Vec::new();
collect_pytest_fixture_definitions_for_name(
&module.body,
bindings,
pytest_info,
identifier.id(),
fixture_name,
fixture_class_key,
&mut matches,
None,
Expand Down
33 changes: 33 additions & 0 deletions pyrefly/lib/test/lsp/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ use pyrefly_python::module::TextRangeWithModule;
use ruff_text_size::TextRange;
use ruff_text_size::TextSize;

use crate::state::require::Require;
use crate::state::state::State;
use crate::test::util::TestEnv;
use crate::test::util::code_frame_of_source_at_range;
use crate::test::util::extract_cursors_for_test;
use crate::test::util::get_batched_lsp_operations_report;
use crate::test::util::get_batched_lsp_operations_report_allow_error;
use crate::test::util::mk_multi_file_state_assert_no_errors;

fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String {
let defs = state
Expand Down Expand Up @@ -190,6 +192,37 @@ Definition Result:
);
}

#[test]
fn pytest_fixture_parameter_goes_to_conftest_fixture_definition() {
let conftest = r#"
import pytest # type: ignore

@pytest.fixture
def answer():
return 42
"#;
let code = r#"
def test_thing(answer):
# ^
assert answer == 42
"#;
let (handles, state) = mk_multi_file_state_assert_no_errors(
&[("main", code), ("conftest", conftest)],
Require::Exports,
);
let handle = handles.get("main").unwrap();
let position = extract_cursors_for_test(code).into_iter().next().unwrap();
assert_eq!(
r#"
Definition Result:
5 | def answer():
^^^^^^
"#
.trim(),
get_test_report(&state, handle, position).trim(),
);
}

#[test]
fn pytest_fixture_parameter_without_fixture_definition_uses_parameter_definition() {
let code = r#"
Expand Down
Loading