From 88a9c10fb6ae6fe1414e39caf6e1ac08f2bf4dec Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 16 Jun 2026 08:07:28 +0900 Subject: [PATCH] fix --- pyrefly/lib/state/lsp/pytest.rs | 90 ++++++++++++++++++- .../state/lsp/quick_fixes/pytest_fixture.rs | 49 +--------- pyrefly/lib/state/pytest.rs | 47 +++++++--- pyrefly/lib/test/lsp/definition.rs | 33 +++++++ 4 files changed, 159 insertions(+), 60 deletions(-) diff --git a/pyrefly/lib/state/lsp/pytest.rs b/pyrefly/lib/state/lsp/pytest.rs index c8a77143cd..218a7679a3 100644 --- a/pyrefly/lib/state/lsp/pytest.rs +++ b/pyrefly/lib/state/lsp/pytest.rs @@ -6,6 +6,10 @@ */ use pyrefly_build::handle::Handle; +use pyrefly_python::module_name::ModuleName; +use pyrefly_python::module_name::ModuleNameWithKind; +use pyrefly_python::module_path::ModulePath; +use pyrefly_python::module_path::ModulePathDetails; use pyrefly_python::symbol_kind::SymbolKind; use ruff_python_ast::AnyNodeRef; use ruff_python_ast::Identifier; @@ -16,14 +20,61 @@ use vec1::Vec1; use super::DefinitionMetadata; use super::FindDefinitionItemWithDocstring; use crate::state::pytest::find_pytest_fixture_definitions_for_parameter; +use crate::state::pytest::find_pytest_fixture_definitions_in_module; use crate::state::pytest::find_pytest_fixture_parameter_references; +use crate::state::pytest::is_pytest_fixture_parameter_context; use crate::state::state::Transaction; +pub(crate) fn pytest_conftest_handles( + transaction: &Transaction<'_>, + handle: &Handle, +) -> Vec { + let module_path = handle.path(); + let Some(mut dir) = module_path.as_path().parent() else { + return Vec::new(); + }; + let root = module_path + .root_of(handle.module()) + .unwrap_or_else(|| dir.to_path_buf()); + let is_memory = matches!(module_path.details(), ModulePathDetails::Memory(_)); + let mut conftest_paths = Vec::new(); + loop { + let conftest_pyi = dir.join("conftest.pyi"); + let conftest_py = dir.join("conftest.py"); + if is_memory { + conftest_paths.push(ModulePath::memory(conftest_pyi.clone())); + conftest_paths.push(ModulePath::memory(conftest_py.clone())); + } else { + if conftest_pyi.exists() { + conftest_paths.push(ModulePath::filesystem(conftest_pyi)); + } + if conftest_py.exists() { + conftest_paths.push(ModulePath::filesystem(conftest_py)); + } + } + if dir == root { + break; + } + let Some(parent) = dir.parent() else { + break; + }; + dir = parent; + } + let mut handles = Vec::new(); + for path in conftest_paths { + let config = transaction + .config_finder() + .python_file(ModuleNameWithKind::guaranteed(ModuleName::unknown()), &path); + handles.push(config.handle_from_module_path(path)); + } + handles +} + impl<'a> Transaction<'a> { /// Resolve a pytest fixture parameter to the fixture functions that can provide it. /// /// This runs during definition lookup. The common non-pytest path is cheap because we first - /// ask bindings for pytest metadata, which is absent in modules that do not import pytest. + /// require either a pytest fixture function or a test-named function. pub(super) fn pytest_fixture_definitions_for_parameter( &self, handle: &Handle, @@ -32,6 +83,9 @@ impl<'a> Transaction<'a> { ) -> Option> { let mod_module = self.get_ast(handle)?; let bindings = self.get_bindings(handle)?; + if !is_pytest_fixture_parameter_context(&bindings, covering_nodes) { + return None; + } let matches = find_pytest_fixture_definitions_for_parameter( mod_module.as_ref(), &bindings, @@ -39,7 +93,7 @@ impl<'a> Transaction<'a> { covering_nodes, ); let module_info = self.get_module_info(handle)?; - let definitions = matches + let mut definitions: Vec<_> = matches .into_iter() .map(|fixture| FindDefinitionItemWithDocstring { metadata: DefinitionMetadata::Variable(Some(SymbolKind::Function)), @@ -49,6 +103,38 @@ impl<'a> Transaction<'a> { display_name: Some(fixture.name.as_str().to_owned()), }) .collect(); + if definitions.is_empty() { + for conftest_handle in pytest_conftest_handles(self, handle) { + let Some(conftest_ast) = self.get_ast(&conftest_handle) else { + continue; + }; + let Some(conftest_bindings) = self.get_bindings(&conftest_handle) else { + continue; + }; + let Some(conftest_module_info) = self.get_module_info(&conftest_handle) else { + continue; + }; + definitions.extend( + find_pytest_fixture_definitions_in_module( + conftest_ast.as_ref(), + &conftest_bindings, + identifier.id(), + None, + ) + .into_iter() + .map(|fixture| FindDefinitionItemWithDocstring { + metadata: DefinitionMetadata::Variable(Some(SymbolKind::Function)), + definition_range: fixture.range, + module: conftest_module_info.clone(), + docstring_range: fixture.docstring_range, + display_name: Some(fixture.name.as_str().to_owned()), + }), + ); + if !definitions.is_empty() { + break; + } + } + } Vec1::try_from_vec(definitions).ok() } diff --git a/pyrefly/lib/state/lsp/quick_fixes/pytest_fixture.rs b/pyrefly/lib/state/lsp/quick_fixes/pytest_fixture.rs index 815b8756c2..a8151484e6 100644 --- a/pyrefly/lib/state/lsp/quick_fixes/pytest_fixture.rs +++ b/pyrefly/lib/state/lsp/quick_fixes/pytest_fixture.rs @@ -10,10 +10,6 @@ use std::collections::HashSet; use dupe::Dupe; use pyrefly_build::handle::Handle; -use pyrefly_python::module_name::ModuleName; -use pyrefly_python::module_name::ModuleNameWithKind; -use pyrefly_python::module_path::ModulePath; -use pyrefly_python::module_path::ModulePathDetails; use pyrefly_python::short_identifier::ShortIdentifier; use pyrefly_types::display::LspDisplayMode; use pyrefly_types::types::Type; @@ -32,6 +28,7 @@ use crate::binding::binding::Key; use crate::state::ide::insert_import_edit; use crate::state::lsp::ImportFormat; use crate::state::lsp::LocalRefactorCodeAction; +use crate::state::lsp::pytest::pytest_conftest_handles; use crate::state::state::Transaction; #[derive(Debug, Default)] @@ -214,48 +211,6 @@ fn fixture_types_for_module(transaction: &Transaction<'_>, handle: &Handle) -> H fixtures } -fn conftest_handles(transaction: &Transaction<'_>, handle: &Handle) -> Vec { - let module_path = handle.path(); - let Some(mut dir) = module_path.as_path().parent() else { - return Vec::new(); - }; - let root = module_path - .root_of(handle.module()) - .unwrap_or_else(|| dir.to_path_buf()); - let is_memory = matches!(module_path.details(), ModulePathDetails::Memory(_)); - let mut conftest_paths = Vec::new(); - loop { - let conftest_pyi = dir.join("conftest.pyi"); - let conftest_py = dir.join("conftest.py"); - if is_memory { - conftest_paths.push(ModulePath::memory(conftest_pyi.clone())); - conftest_paths.push(ModulePath::memory(conftest_py.clone())); - } else { - if conftest_pyi.exists() { - conftest_paths.push(ModulePath::filesystem(conftest_pyi)); - } - if conftest_py.exists() { - conftest_paths.push(ModulePath::filesystem(conftest_py)); - } - } - if dir == root { - break; - } - let Some(parent) = dir.parent() else { - break; - }; - dir = parent; - } - let mut handles = Vec::new(); - for path in conftest_paths { - let config = transaction - .config_finder() - .python_file(ModuleNameWithKind::guaranteed(ModuleName::unknown()), &path); - handles.push(config.handle_from_module_path(path)); - } - handles -} - fn import_edits_for_type( transaction: &Transaction<'_>, ast: &ModModule, @@ -415,7 +370,7 @@ pub(crate) fn pytest_fixture_type_annotation_code_actions( } let mut fixture_types = fixture_types_for_module(transaction, handle); - for conftest_handle in conftest_handles(transaction, handle) { + for conftest_handle in pytest_conftest_handles(transaction, handle) { let conftest_types = fixture_types_for_module(transaction, &conftest_handle); for (name, ty) in conftest_types { fixture_types.entry(name).or_insert(ty); diff --git a/pyrefly/lib/state/pytest.rs b/pyrefly/lib/state/pytest.rs index 5815d5db97..47cddeabe5 100644 --- a/pyrefly/lib/state/pytest.rs +++ b/pyrefly/lib/state/pytest.rs @@ -223,15 +223,36 @@ pub(crate) fn find_pytest_fixture_definitions_for_parameter( identifier: &Identifier, covering_nodes: &[AnyNodeRef], ) -> Vec { + if !is_pytest_fixture_parameter_context(bindings, covering_nodes) { + return Vec::new(); + } let Some(pytest_info) = bindings.pytest_info() else { return Vec::new(); }; + let class_def = covering_nodes.iter().find_map(|node| match node { + AnyNodeRef::StmtClassDef(stmt) => Some(stmt), + _ => None, + }); + let class_key = class_def.and_then(|def| class_key_for_definition(bindings, def)); + let Some(fixture_class_key) = + pytest_info.visible_fixture_class_key(identifier.id(), class_key.as_ref()) + else { + return Vec::new(); + }; + + find_pytest_fixture_definitions_in_module(module, bindings, identifier.id(), fixture_class_key) +} + +pub(crate) fn is_pytest_fixture_parameter_context( + bindings: &Bindings, + covering_nodes: &[AnyNodeRef], +) -> bool { let function_def = covering_nodes.iter().find_map(|node| match node { AnyNodeRef::StmtFunctionDef(stmt) => Some(stmt), _ => None, }); let Some(function_def) = function_def else { - return Vec::new(); + return false; }; let class_def = covering_nodes.iter().find_map(|node| match node { AnyNodeRef::StmtClassDef(stmt) => Some(stmt), @@ -240,23 +261,27 @@ pub(crate) fn find_pytest_fixture_definitions_for_parameter( let class_is_test = class_def.map(|def| is_pytest_test_class(def)); let class_key = class_def.and_then(|def| class_key_for_definition(bindings, def)); - if !is_pytest_fixture_function(function_def, class_key, pytest_info) - && !is_pytest_test_function(function_def, class_is_test) - { - return Vec::new(); - } - let Some(fixture_class_key) = - pytest_info.visible_fixture_class_key(identifier.id(), class_key.as_ref()) - else { + bindings + .pytest_info() + .is_some_and(|pytest_info| is_pytest_fixture_function(function_def, class_key, pytest_info)) + || is_pytest_test_function(function_def, class_is_test) +} + +pub(crate) fn find_pytest_fixture_definitions_in_module( + module: &ModModule, + bindings: &Bindings, + fixture_name: &Name, + fixture_class_key: Option>, +) -> Vec { + let Some(pytest_info) = bindings.pytest_info() else { return Vec::new(); }; - let mut matches = Vec::new(); collect_pytest_fixture_definitions_for_name( &module.body, bindings, pytest_info, - identifier.id(), + fixture_name, fixture_class_key, &mut matches, None, diff --git a/pyrefly/lib/test/lsp/definition.rs b/pyrefly/lib/test/lsp/definition.rs index 34bbe0e8aa..255e1e2ccd 100644 --- a/pyrefly/lib/test/lsp/definition.rs +++ b/pyrefly/lib/test/lsp/definition.rs @@ -12,12 +12,14 @@ use pyrefly_python::module::TextRangeWithModule; use ruff_text_size::TextRange; use ruff_text_size::TextSize; +use crate::state::require::Require; use crate::state::state::State; use crate::test::util::TestEnv; use crate::test::util::code_frame_of_source_at_range; use crate::test::util::extract_cursors_for_test; use crate::test::util::get_batched_lsp_operations_report; use crate::test::util::get_batched_lsp_operations_report_allow_error; +use crate::test::util::mk_multi_file_state_assert_no_errors; fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String { let defs = state @@ -190,6 +192,37 @@ Definition Result: ); } +#[test] +fn pytest_fixture_parameter_goes_to_conftest_fixture_definition() { + let conftest = r#" +import pytest # type: ignore + +@pytest.fixture +def answer(): + return 42 +"#; + let code = r#" +def test_thing(answer): +# ^ + assert answer == 42 +"#; + let (handles, state) = mk_multi_file_state_assert_no_errors( + &[("main", code), ("conftest", conftest)], + Require::Exports, + ); + let handle = handles.get("main").unwrap(); + let position = extract_cursors_for_test(code).into_iter().next().unwrap(); + assert_eq!( + r#" +Definition Result: +5 | def answer(): + ^^^^^^ +"# + .trim(), + get_test_report(&state, handle, position).trim(), + ); +} + #[test] fn pytest_fixture_parameter_without_fixture_definition_uses_parameter_definition() { let code = r#"