Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions os/src/arch/arch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use core::fmt::Debug;
use crate::{
arch::{address::UA, cpu_ops::CpuOps, task::ExecStackLayout, virtual_memory::VirtualMemory},
kernel::syscall::syscall_frame::SyscallFrame,
mm::page_table::PagingError,
uapi::signal::MContextT,
};

Expand Down Expand Up @@ -87,13 +88,13 @@ pub trait Arch: CpuOps + VirtualMemory {
/// - `src` 必须是有效的用户空间虚拟地址
/// - `dst` 必须指向足够大的内核缓冲区
/// - `len` 字节必须在合法范围内
unsafe fn copy_from_user(src: UA, dst: *mut u8, len: usize) -> Result<(), ()>;
unsafe fn copy_from_user(src: UA, dst: *mut u8, len: usize) -> Result<(), PagingError>;

/// 尝试从用户空间复制数据(非阻塞版本,不处理缺页)
///
/// # Safety
/// 同上
unsafe fn try_copy_from_user(src: UA, dst: *mut u8, len: usize) -> Result<(), ()>;
unsafe fn try_copy_from_user(src: UA, dst: *mut u8, len: usize) -> Result<(), PagingError>;

/// 从内核空间复制数据到用户空间
///
Expand All @@ -102,13 +103,17 @@ pub trait Arch: CpuOps + VirtualMemory {
/// - `dst` 必须是有效的用户空间虚拟地址
/// - `src` 必须指向有效内核数据
/// - `len` 字节必须在合法范围内
unsafe fn copy_to_user(src: *const u8, dst: UA, len: usize) -> Result<(), ()>;
unsafe fn copy_to_user(src: *const u8, dst: UA, len: usize) -> Result<(), PagingError>;

/// 从用户空间复制以 '\0' 结尾的字符串
///
/// # Safety
/// 同上
unsafe fn copy_strn_from_user(src: UA, dst: *mut u8, max_len: usize) -> Result<usize, ()>;
unsafe fn copy_strn_from_user(
src: UA,
dst: *mut u8,
max_len: usize,
) -> Result<usize, PagingError>;

// ---- 系统信息 ----

Expand Down
43 changes: 24 additions & 19 deletions os/src/arch/arch_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ macro_rules! impl_arch {
($arch:ty, $process_space:ty, $kernel_space:ty) => {
use $crate::arch::virtual_memory::VirtualMemory;
use $crate::mm::address::Ppn;
use $crate::mm::page_table::PagingError;
use $crate::sync::SpinLock;

lazy_static::lazy_static! {
Expand Down Expand Up @@ -47,11 +48,11 @@ macro_rules! impl_arch {
src: $crate::arch::address::UA,
dst: *mut u8,
len: usize,
) -> Result<(), ()> {
) -> Result<(), PagingError> {
let src = src.as_usize();
validate_user_copy_range(src, len, false)?;
if len != 0 && dst.is_null() {
return Err(());
return Err(PagingError::InvalidAddress);
}
let _guard = trap::SumGuard::new();
unsafe { core::ptr::copy_nonoverlapping(src as *const u8, dst, len) };
Expand All @@ -62,19 +63,19 @@ macro_rules! impl_arch {
src: $crate::arch::address::UA,
dst: *mut u8,
len: usize,
) -> Result<(), ()> {
) -> Result<(), PagingError> {
unsafe { Self::copy_from_user(src, dst, len) }
}

unsafe fn copy_to_user(
src: *const u8,
dst: $crate::arch::address::UA,
len: usize,
) -> Result<(), ()> {
) -> Result<(), PagingError> {
let dst = dst.as_usize();
validate_user_copy_range(dst, len, true)?;
if len != 0 && src.is_null() {
return Err(());
return Err(PagingError::InvalidAddress);
}
let _guard = trap::SumGuard::new();
unsafe { core::ptr::copy_nonoverlapping(src, dst as *mut u8, len) };
Expand All @@ -85,18 +86,18 @@ macro_rules! impl_arch {
src: $crate::arch::address::UA,
dst: *mut u8,
max_len: usize,
) -> Result<usize, ()> {
) -> Result<usize, PagingError> {
let src = src.as_usize();
if !(constant::USER_BASE..=<$arch as VirtualMemory>::USER_TOP).contains(&src) {
return Err(());
return Err(PagingError::InvalidAddress);
}
if max_len != 0 && dst.is_null() {
return Err(());
return Err(PagingError::InvalidAddress);
}
let _guard = trap::SumGuard::new();
let mut i = 0;
while i < max_len {
let cur = src.checked_add(i).ok_or(())?;
let cur = src.checked_add(i).ok_or(PagingError::InvalidAddress)?;
validate_user_copy_range(cur, 1, false)?;
let byte = unsafe { core::ptr::read_volatile(cur as *const u8) };
unsafe { *dst.add(i) = byte };
Expand Down Expand Up @@ -138,42 +139,46 @@ macro_rules! impl_arch {
}
}

fn validate_user_copy_range(start: usize, len: usize, write: bool) -> Result<(), ()> {
fn validate_user_copy_range(
start: usize,
len: usize,
write: bool,
) -> Result<(), PagingError> {
use $crate::mm::address::{PageNum, VA, Vpn};
use $crate::mm::page_table::{PageTableInner, UniversalPTEFlag};

if len == 0 {
return Ok(());
}
if !(constant::USER_BASE..=<$arch as VirtualMemory>::USER_TOP).contains(&start) {
return Err(());
return Err(PagingError::InvalidAddress);
}
let end = start.checked_add(len).ok_or(())?;
let last = end.checked_sub(1).ok_or(())?;
let end = start.checked_add(len).ok_or(PagingError::InvalidAddress)?;
let last = end.checked_sub(1).ok_or(PagingError::InvalidAddress)?;
if last > <$arch as VirtualMemory>::USER_TOP {
return Err(());
return Err(PagingError::InvalidAddress);
}

let space = $crate::kernel::current_memory_space();
let guard = space.lock();
let mut cur = start;
while cur < end {
let vpn = Vpn::from_addr_floor(VA::from_usize(cur));
let (_, _, flags) = guard.page_table().walk(vpn).map_err(|_| ())?;
let (_, _, flags) = guard.page_table().walk(vpn)?;
let required = UniversalPTEFlag::VALID | UniversalPTEFlag::USER_ACCESSIBLE;
if !flags.contains(required) {
return Err(());
return Err(PagingError::PermissionDenied);
}
if write {
if !flags.contains(UniversalPTEFlag::WRITEABLE) {
return Err(());
return Err(PagingError::PermissionDenied);
}
} else if !flags.contains(UniversalPTEFlag::READABLE) {
return Err(());
return Err(PagingError::PermissionDenied);
}
let next_page = (cur & !($crate::config::PAGE_SIZE - 1))
.checked_add($crate::config::PAGE_SIZE)
.ok_or(())?;
.ok_or(PagingError::InvalidAddress)?;
cur = core::cmp::min(next_page, end);
}
Ok(())
Expand Down
51 changes: 29 additions & 22 deletions os/src/arch/loongarch/mm/page_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,18 @@ impl PageTableInnerTrait<PageTableEntry> for PageTableInner {
}

/// 创建新的用户页表
fn new() -> Self {
let frame = alloc_frame().expect("Failed to allocate root page table frame");
fn new() -> PagingResult<Self> {
let frame = alloc_frame().ok_or(PagingError::FrameAllocFailed)?;
let root_ppn = frame.ppn();

// 清零根页表
Self::clear_page_table(root_ppn);

Self {
Ok(Self {
root_ppn,
frames: alloc::vec![frame],
is_user: true,
}
})
}

/// 从已有的 PPN 创建页表(不拥有帧所有权)
Expand All @@ -183,18 +183,18 @@ impl PageTableInnerTrait<PageTableEntry> for PageTableInner {
}

/// 创建新的内核页表
fn new_as_kernel_table() -> Self {
let frame = alloc_frame().expect("Failed to allocate kernel page table frame");
fn new_as_kernel_table() -> PagingResult<Self> {
let frame = alloc_frame().ok_or(PagingError::FrameAllocFailed)?;
let root_ppn = frame.ppn();

// 清零根页表
Self::clear_page_table(root_ppn);

Self {
Ok(Self {
root_ppn,
frames: alloc::vec![frame],
is_user: false,
}
})
}

fn root_ppn(&self) -> Ppn {
Expand Down Expand Up @@ -251,7 +251,7 @@ impl PageTableInnerTrait<PageTableEntry> for PageTableInner {
ppn: Ppn,
_page_size: PageSize,
flags: UniversalPTEFlag,
) -> PagingResult<()> {
) -> Result<(), PagingError> {
// 验证标志位:叶子节点必须至少设置可读或可执行
if !flags.intersects(
UniversalPTEFlag::READABLE | UniversalPTEFlag::WRITEABLE | UniversalPTEFlag::EXECUTABLE,
Expand Down Expand Up @@ -305,7 +305,7 @@ impl PageTableInnerTrait<PageTableEntry> for PageTableInner {
}

/// 解除虚拟页的映射
fn unmap(&mut self, vpn: Vpn) -> PagingResult<()> {
fn unmap(&mut self, vpn: Vpn) -> Result<(), PagingError> {
let mut current_ppn = self.root_ppn;
let vpn_value = vpn.as_usize();

Expand Down Expand Up @@ -337,13 +337,13 @@ impl PageTableInnerTrait<PageTableEntry> for PageTableInner {
target_ppn: Ppn,
page_size: PageSize,
flags: UniversalPTEFlag,
) -> PagingResult<()> {
) -> Result<(), PagingError> {
self.unmap(vpn)?;
self.map(vpn, target_ppn, page_size, flags)
}

/// 更新页表项标志位
fn update_flags(&mut self, vpn: Vpn, flags: UniversalPTEFlag) -> PagingResult<()> {
fn update_flags(&mut self, vpn: Vpn, flags: UniversalPTEFlag) -> Result<(), PagingError> {
let mut current_ppn = self.root_ppn;
let vpn_value = vpn.as_usize();

Expand Down Expand Up @@ -402,7 +402,7 @@ impl PageTableInner {
page_size: PageSize,
flags: UniversalPTEFlag,
_batch: Option<&mut TlbBatchContext>,
) -> PagingResult<()> {
) -> Result<(), PagingError> {
<Self as PageTableInnerTrait<PageTableEntry>>::map(self, vpn, ppn, page_size, flags)?;
<Self as PageTableInnerTrait<PageTableEntry>>::tlb_flush(vpn);
Ok(())
Expand All @@ -413,7 +413,7 @@ impl PageTableInner {
&mut self,
vpn: Vpn,
_batch: Option<&mut TlbBatchContext>,
) -> PagingResult<()> {
) -> Result<(), PagingError> {
<Self as PageTableInnerTrait<PageTableEntry>>::unmap(self, vpn)?;
<Self as PageTableInnerTrait<PageTableEntry>>::tlb_flush(vpn);
Ok(())
Expand All @@ -425,7 +425,7 @@ impl PageTableInner {
vpn: Vpn,
flags: UniversalPTEFlag,
_batch: Option<&mut TlbBatchContext>,
) -> PagingResult<()> {
) -> Result<(), PagingError> {
<Self as PageTableInnerTrait<PageTableEntry>>::update_flags(self, vpn, flags)?;
<Self as PageTableInnerTrait<PageTableEntry>>::tlb_flush(vpn);
Ok(())
Expand Down Expand Up @@ -530,9 +530,16 @@ mod page_table_tests {
use crate::mm::page_table::PageTableInner as PageTableInnerTrait;
use crate::{kassert, test_case};

fn new_page_table() -> PageTableInner {
match PageTableInner::new() {
Ok(pt) => pt,
Err(err) => panic!("failed to create test page table: {:?}", err),
}
}

// 1. 页表创建测试
test_case!(test_pt_create, {
let pt = PageTableInner::new();
let pt = new_page_table();
// 根 PPN 应该有效 (大于 0)
kassert!(pt.root_ppn().as_usize() > 0);
// 默认创建为用户页表
Expand All @@ -541,7 +548,7 @@ mod page_table_tests {

// 2. 映射与转换测试
test_case!(test_pt_map_translate, {
let mut pt = PageTableInner::new();
let mut pt = new_page_table();
let vpn = Vpn::from_usize(0x1000);
let ppn = Ppn::from_usize(0x80000);

Expand All @@ -560,7 +567,7 @@ mod page_table_tests {

// 3. 解除映射测试
test_case!(test_pt_unmap, {
let mut pt = PageTableInner::new();
let mut pt = new_page_table();
let vpn = Vpn::from_usize(0x1000);
let ppn = Ppn::from_usize(0x80000);

Expand All @@ -580,7 +587,7 @@ mod page_table_tests {

// 4. 错误测试:已映射
test_case!(test_pt_error_already_mapped, {
let mut pt = PageTableInner::new();
let mut pt = new_page_table();
let vpn = Vpn::from_usize(0x1000);

// 第一次映射成功
Expand All @@ -604,7 +611,7 @@ mod page_table_tests {

// 5. 页表遍历 (Walk) 测试
test_case!(test_pt_walk, {
let mut pt = PageTableInner::new();
let mut pt = new_page_table();
let vpn = Vpn::from_usize(0x1000);
let ppn = Ppn::from_usize(0x80000);
let original_flags = UniversalPTEFlag::kernel_rw();
Expand All @@ -629,7 +636,7 @@ mod page_table_tests {

// 6. 更新标志位测试
test_case!(test_pt_update_flags, {
let mut pt = PageTableInner::new();
let mut pt = new_page_table();
let vpn = Vpn::from_usize(0x1000);
let ppn = Ppn::from_usize(0x80000);

Expand All @@ -652,7 +659,7 @@ mod page_table_tests {

// 7. 多重映射测试
test_case!(test_pt_multiple_mappings, {
let mut pt = PageTableInner::new();
let mut pt = new_page_table();

// 映射多个 VPN
for i in 0..10 {
Expand Down
Loading
Loading