diff --git a/Cargo.lock b/Cargo.lock index e96a5158b4..b0802db5b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -431,6 +431,7 @@ dependencies = [ "bcm2835-sdhci", "log", "simple-sdmmc", + "simple-ahci", ] [[package]] @@ -2356,6 +2357,18 @@ dependencies = [ "volatile 0.6.1", ] +[[package]] +name = "simple-ahci" +version = "0.1.0" +source = "git+https://github.com/Starry-OS/simple-ahci.git?rev=36d0979#36d0979fedb17c7846b78d1a63f944da31f96186" +dependencies = [ + "bitfield-struct", + "log", + "thiserror", + "volatile 0.6.1", +] + + [[package]] name = "slab_allocator" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index d9c36814e7..ef01765528 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,8 +79,9 @@ axdriver_net = { git = "https://github.com/arceos-org/axdriver_crates.git", tag axdriver_pci = { git = "https://github.com/arceos-org/axdriver_crates.git", tag = "dev-v01" } axdriver_virtio = { git = "https://github.com/arceos-org/axdriver_crates.git", tag = "dev-v01" } axdriver_vsock = { git = "https://github.com/arceos-org/axdriver_crates.git", tag = "dev-v01" } -axerrno = "0.1" -axio = "0.1" +axerrno = "0.2" +axfs-ng-vfs = "0.1" +axio = { git = "https://github.com/arceos-org/axio.git", tag = "dev-v02" } axklib = { git = "https://github.com/arceos-hypervisor/axklib.git" } # FIXME: pin to a specific commit or tag axplat = { git = "https://github.com/arceos-org/axplat_crates.git", tag = "dev-v03" } axplat-aarch64-bsta1000b = { git = "https://github.com/arceos-org/axplat_crates.git", tag = "dev-v03" } @@ -91,11 +92,13 @@ axplat-loongarch64-qemu-virt = { git = "https://github.com/arceos-org/axplat_cra axplat-riscv64-qemu-virt = { git = "https://github.com/arceos-org/axplat_crates.git", tag = "dev-v03" } axplat-x86-pc = { git = "https://github.com/arceos-org/axplat_crates.git", tag = "dev-v03" } axpoll = "0.1" +axwatchdog = { git = "https://github.com/kylin-x-kernel/axwatchdog.git", branch = "lzx/dev"} bindgen = "0.72" cfg-if = "1.0" chrono = { version = "0.4", default-features = false } crate_interface = "0.1.4" ctor_bare = "0.2" +enum_dispatch = "0.3" event-listener = { version = "5.4.0", default-features = false } kernel_guard = "0.1" kspin = "0.1" @@ -104,7 +107,9 @@ lazy_static = { version = "1.5", features = ["spin_no_std"] } lock_api = { version = "0.4", default-features = false } log = "0.4" memory_addr = "0.4" -page_table_multiarch = "0.5" +page_table_multiarch = { git = "https://github.com/arceos-org/page_table_multiarch.git", tag = "dev-v05", features = [ + "axerrno", +] } percpu = "0.2" scope-local = "0.1" spin = "0.10" diff --git a/api/arceos_api/Cargo.toml b/api/arceos_api/Cargo.toml index bf30e09efd..b9be846b53 100644 --- a/api/arceos_api/Cargo.toml +++ b/api/arceos_api/Cargo.toml @@ -22,8 +22,6 @@ fs = ["dep:axfs", "dep:axdriver", "axfeat/fs"] net = ["dep:axnet", "dep:axdriver", "axfeat/net"] display = ["dep:axdisplay", "dep:axdriver", "axfeat/display"] -myfs = ["axfeat/myfs"] - # Use dummy functions if the feature is not enabled dummy-if-not-enabled = [] diff --git a/api/axfeat/Cargo.toml b/api/axfeat/Cargo.toml index bc101e5aa6..4732684813 100644 --- a/api/axfeat/Cargo.toml +++ b/api/axfeat/Cargo.toml @@ -61,9 +61,9 @@ fs = [ "dep:axfs", "axruntime/fs", ] # TODO: try to remove "paging" -myfs = ["axfs?/myfs"] -ext4fs = ["axfs?/ext4fs"] -fatfs = ["axfs?/fatfs"] +fs-ext4 = ["fs", "axfs/ext4"] +fs-fat = ["fs", "axfs/fat"] +fs-times = ["fs", "axfs/times"] # Networking net = ["alloc", "paging", "axdriver/virtio-net", "dep:axnet", "axruntime/net"] @@ -98,12 +98,16 @@ driver-sdmmc = ["axdriver?/sdmmc"] driver-ixgbe = ["axdriver?/ixgbe"] driver-fxmac = ["axdriver?/fxmac"] # fxmac ethernet driver for PhytiumPi driver-bcm2835-sdhci = ["axdriver?/bcm2835-sdhci"] +driver-ahci = ["axdriver?/ahci"] driver-dyn = ["paging", "axruntime/driver-dyn", "axdriver/dyn"] # Backtrace dwarf = ["alloc", "axbacktrace/dwarf"] +# Lockup detect +watchdog = ["axruntime/watchdog", "axtask/watchdog"] + [dependencies] axalloc = { workspace = true, optional = true } axbacktrace.workspace = true diff --git a/examples/shell/Cargo.toml b/examples/shell/Cargo.toml index 73c10f06d2..944d8417a1 100644 --- a/examples/shell/Cargo.toml +++ b/examples/shell/Cargo.toml @@ -7,16 +7,7 @@ authors = ["Yuekai Jia "] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -use-ramfs = [ - "axstd/myfs", - "dep:axfs_vfs", - "dep:axfs_ramfs", - "dep:crate_interface", -] default = [] [dependencies] -axfs_vfs = { version = "0.1", optional = true } -axfs_ramfs = { version = "0.1", optional = true } -crate_interface = { workspace = true, optional = true } axstd = { workspace = true, features = ["alloc", "fs"], optional = true } diff --git a/examples/shell/src/main.rs b/examples/shell/src/main.rs index 56e999bc99..6e92e55d8a 100644 --- a/examples/shell/src/main.rs +++ b/examples/shell/src/main.rs @@ -17,9 +17,6 @@ fn path_to_str(path: &str) -> &str { mod cmd; -#[cfg(feature = "use-ramfs")] -mod ramfs; - use std::io::prelude::*; const LF: u8 = b'\n'; diff --git a/examples/shell/src/ramfs.rs b/examples/shell/src/ramfs.rs deleted file mode 100644 index 450cafaebe..0000000000 --- a/examples/shell/src/ramfs.rs +++ /dev/null @@ -1,15 +0,0 @@ -extern crate alloc; - -use alloc::sync::Arc; -use axfs_ramfs::RamFileSystem; -use axfs_vfs::VfsOps; -use std::os::arceos::api::fs::{AxDisk, MyFileSystemIf}; - -struct MyFileSystemIfImpl; - -#[crate_interface::impl_interface] -impl MyFileSystemIf for MyFileSystemIfImpl { - fn new_myfs(_disk: AxDisk) -> Arc { - Arc::new(RamFileSystem::new()) - } -} diff --git a/modules/axdriver/Cargo.toml b/modules/axdriver/Cargo.toml index c721055cf4..27a8c02bc1 100644 --- a/modules/axdriver/Cargo.toml +++ b/modules/axdriver/Cargo.toml @@ -33,6 +33,7 @@ virtio-socket = ["vsock", "virtio", "axdriver_virtio/socket"] ramdisk = ["block", "axdriver_block/ramdisk"] bcm2835-sdhci = ["block", "axdriver_block/bcm2835-sdhci"] sdmmc = ["block", "axdriver_block/sdmmc", "dep:axhal", "dep:axconfig"] +ahci = ["block", "axdriver_block/ahci", "dep:axhal", "dep:axconfig"] ixgbe = ["net", "axdriver_net/ixgbe", "dep:axalloc", "dep:axhal", "dep:axdma"] fxmac = ["net", "axdriver_net/fxmac", "dep:axalloc", "dep:axhal", "dep:axdma"] # more devices example: e1000 = ["net", "axdriver_net/e1000"] diff --git a/modules/axdriver/build.rs b/modules/axdriver/build.rs index 7db8268643..6788e530b4 100644 --- a/modules/axdriver/build.rs +++ b/modules/axdriver/build.rs @@ -1,5 +1,5 @@ const NET_DEV_FEATURES: &[&str] = &["fxmac", "ixgbe", "virtio-net"]; -const BLOCK_DEV_FEATURES: &[&str] = &["ramdisk", "sdmmc", "bcm2835-sdhci", "virtio-blk"]; +const BLOCK_DEV_FEATURES: &[&str] = &["ahci", "ramdisk", "sdmmc", "bcm2835-sdhci", "virtio-blk"]; const DISPLAY_DEV_FEATURES: &[&str] = &["virtio-gpu"]; const INPUT_DEV_FEATURES: &[&str] = &["virtio-input"]; const VSOCK_DEV_FEATURES: &[&str] = &["virtio-socket"]; diff --git a/modules/axdriver/src/drivers.rs b/modules/axdriver/src/drivers.rs index b262cfc249..c68be62602 100644 --- a/modules/axdriver/src/drivers.rs +++ b/modules/axdriver/src/drivers.rs @@ -97,6 +97,43 @@ cfg_if::cfg_if! { } } +cfg_if::cfg_if! { + if #[cfg(block_dev = "ahci")] { + pub struct AhciHalImpl; + impl axdriver_block::ahci::AhciHal for AhciHalImpl { + fn virt_to_phys(va: usize) -> usize { + axhal::mem::virt_to_phys(va.into()).as_usize() + } + + fn current_ms() -> u64 { + axhal::time::monotonic_time_nanos() / 1_000_000 + } + + fn flush_dcache() { + #[cfg(target_arch = "loongarch64")] + unsafe { + // LoongArch64: Ensure data cache operations are synchronized for AHCI DMA coherency. + core::arch::asm!("dbar 0"); + } + } + } + + pub struct AhciDriver; + register_block_driver!(AhciDriver, axdriver_block::ahci::AhciDriver); + + impl DriverProbe for AhciDriver { + fn probe_global() -> Option { + let ahci = unsafe { + axdriver_block::ahci::AhciDriver::::try_new( + axhal::mem::phys_to_virt(axconfig::devices::AHCI_PADDR.into()).into(), + )? + }; + Some(AxDeviceEnum::from_block(ahci)) + } + } + } +} + cfg_if::cfg_if! { if #[cfg(block_dev = "bcm2835-sdhci")]{ pub struct BcmSdhciDriver; diff --git a/modules/axdriver/src/macros.rs b/modules/axdriver/src/macros.rs index 65b487f78f..c5defbd22c 100644 --- a/modules/axdriver/src/macros.rs +++ b/modules/axdriver/src/macros.rs @@ -85,6 +85,11 @@ macro_rules! for_each_drivers { type $drv_type = crate::drivers::SdMmcDriver; $code } + #[cfg(block_dev = "ahci")] + { + type $drv_type = crate::drivers::AhciDriver; + $code + } #[cfg(block_dev = "bcm2835-sdhci")] { type $drv_type = crate::drivers::BcmSdhciDriver; diff --git a/modules/axfs/Cargo.toml b/modules/axfs/Cargo.toml index 7104e05d72..f6faa12093 100644 --- a/modules/axfs/Cargo.toml +++ b/modules/axfs/Cargo.toml @@ -2,60 +2,52 @@ name = "axfs" version.workspace = true edition.workspace = true -authors = ["Yuekai Jia "] +authors = ["Mivik "] description = "ArceOS filesystem module" license.workspace = true homepage.workspace = true -repository = "https://github.com/arceos-org/arceos/tree/main/modules/axfs" -documentation = "https://arceos-org.github.io/arceos/axfs/index.html" [features] -devfs = ["dep:axfs_devfs"] -ramfs = ["dep:axfs_ramfs"] -procfs = ["dep:axfs_ramfs"] -sysfs = ["dep:axfs_ramfs"] -ext4fs = ["dep:lwext4_rust"] -fatfs = ["dep:fatfs"] -myfs = ["dep:crate_interface"] -use-ramdisk = [] - -default = ["devfs", "ramfs", "fatfs", "procfs", "sysfs"] +default = [] +use-ramdisk = [] # TODO: init ramdisk +fat = ["dep:fatfs"] +ext4 = ["dep:lwext4_rust"] +times = [] +std = ["lwext4_rust?/std"] [dependencies] +axalloc = { workspace = true } axdriver = { workspace = true, features = ["block"] } -axdriver_block.workspace = true -axerrno.workspace = true -axfs_devfs = { version = "0.1", optional = true } -axfs_ramfs = { version = "0.1", optional = true } -axfs_vfs = "0.1" +axerrno = { workspace = true } +axfs-ng-vfs = { workspace = true } +axhal = { workspace = true } axio = { workspace = true, features = ["alloc"] } -axsync.workspace = true -cap_access = "0.1" -cfg-if.workspace = true -crate_interface = { workspace = true, optional = true } -lazyinit.workspace = true -log.workspace = true -scope-local.workspace = true +axpoll = { workspace = true } +axsync = { workspace = true } +bitflags = "2.10" +cfg-if = { workspace = true } +chrono = { workspace = true } +intrusive-collections = "0.9.7" +kspin = { workspace = true } +log = { workspace = true } +lru = "0.16.0" +scope-local = { workspace = true } +slab = { version = "0.4.9", default-features = false } +spin = { workspace = true } -[dependencies.fatfs] -git = "https://github.com/rafalh/rust-fatfs" -rev = "4eccb50" -optional = true +[dependencies.lwext4_rust] +git = "https://github.com/Starry-OS/lwext4_rust.git" +rev = "033fa2c" default-features = false -features = [ # no std - "alloc", - "lfn", - "log_level_trace", - "unicode", -] +optional = true -[dependencies.lwext4_rust] -git = "https://github.com/Josen-B/lwext4_rust.git" -rev = "99b3e5c" +[dependencies.fatfs] +git = "https://github.com/Starry-OS/rust-fatfs.git" +rev = "2685439" +default-features = false optional = true +features = ["alloc", "lfn", "log_level_trace", "unicode"] [dev-dependencies] axdriver = { workspace = true, features = ["block", "ramdisk"] } -axdriver_block = { workspace = true, features = ["ramdisk"] } -axsync = { workspace = true, features = ["multitask"] } -axtask = { workspace = true, features = ["test"] } +env_logger = "0.11.8" diff --git a/modules/axfs/resources/create_test_img.sh b/modules/axfs/resources/create_test_img.sh deleted file mode 100755 index d51cef84d2..0000000000 --- a/modules/axfs/resources/create_test_img.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# From https://github.com/rafalh/rust-fatfs/blob/master/scripts/create-test-img.sh - -CUR_DIR=`dirname $0` - -echo $OUT_DIR - -create_test_img() { - local name=$1 - local blkcount=$2 - local fatSize=$3 - dd if=/dev/zero of="$name" bs=1024 count=$blkcount - mkfs.vfat -s 1 -F $fatSize -n "Test!" -i 12345678 "$name" - mkdir -p mnt - sudo mount -o loop "$name" mnt -o rw,uid=$USER,gid=$USER - for i in $(seq 1 1000); do - echo "Rust is cool!" >>"mnt/long.txt" - done - echo "Rust is cool!" >>"mnt/short.txt" - mkdir -p "mnt/very/long/path" - echo "Rust is cool!" >>"mnt/very/long/path/test.txt" - mkdir -p "mnt/very-long-dir-name" - echo "Rust is cool!" >>"mnt/very-long-dir-name/very-long-file-name.txt" - - sudo umount mnt -} - -create_test_img "$CUR_DIR/fat16.img" 2500 16 -create_test_img "$CUR_DIR/fat32.img" 34000 32 diff --git a/modules/axfs/resources/fat16.img b/modules/axfs/resources/fat16.img deleted file mode 100644 index 6a220b4f3a..0000000000 Binary files a/modules/axfs/resources/fat16.img and /dev/null differ diff --git a/modules/axfs/resources/make_fs.sh b/modules/axfs/resources/make_fs.sh new file mode 100755 index 0000000000..0ac53dd072 --- /dev/null +++ b/modules/axfs/resources/make_fs.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +CUR_DIR=$(dirname $0) + +write_fs() { + for i in $(seq 1 1000); do + echo "Rust is cool!" >>long.txt + done + echo "Rust is cool!" >short.txt + mkdir -p a/long/path + echo "Rust is cool!" >a/long/path/test.txt + mkdir -p very-long-dir-name + echo "Rust is cool!" >>very-long-dir-name/very-long-file-name.txt +} +init_fs() { + local name=$1 + local options=$2 + mkdir -p mnt + sudo mount -o loop "$name" mnt -o "$options" + + sudo chmod 777 mnt + + cd mnt + write_fs + cd .. + + sudo umount mnt + rm -r mnt +} +create_fat_img() { + local name=$1 + local kb=$2 + local fatSize=$3 + dd if=/dev/zero of="$name" bs=1K count=$kb + mkfs.vfat -s 1 -F $fatSize "$name" + + init_fs "$name" rw,uid=$USER,gid=$USER +} +create_ext4_img() { + local name=$1 + local kb=$2 + dd if=/dev/zero of="$name" bs=1K count=$kb + mkfs.ext4 -O ^metadata_csum "$name" + + init_fs "$name" rw +} + +create_fat_img "$CUR_DIR/fat16.img" 2500 16 +create_fat_img "$CUR_DIR/fat32.img" 34000 32 + +create_ext4_img "$CUR_DIR/ext4.img" 30000 diff --git a/modules/axfs/src/api/dir.rs b/modules/axfs/src/api/dir.rs deleted file mode 100644 index a2711cfad7..0000000000 --- a/modules/axfs/src/api/dir.rs +++ /dev/null @@ -1,150 +0,0 @@ -use alloc::string::String; -use axio::Result; -use core::fmt; - -use super::FileType; -use crate::fops; - -/// Iterator over the entries in a directory. -pub struct ReadDir<'a> { - path: &'a str, - inner: fops::Directory, - buf_pos: usize, - buf_end: usize, - end_of_stream: bool, - dirent_buf: [fops::DirEntry; 31], -} - -/// Entries returned by the [`ReadDir`] iterator. -pub struct DirEntry<'a> { - dir_path: &'a str, - entry_name: String, - entry_type: FileType, -} - -/// A builder used to create directories in various manners. -#[derive(Default, Debug)] -pub struct DirBuilder { - recursive: bool, -} - -impl<'a> ReadDir<'a> { - pub(super) fn new(path: &'a str) -> Result { - let mut opts = fops::OpenOptions::new(); - opts.read(true); - let inner = fops::Directory::open_dir(path, &opts)?; - const EMPTY: fops::DirEntry = fops::DirEntry::default(); - let dirent_buf = [EMPTY; 31]; - Ok(ReadDir { - path, - inner, - end_of_stream: false, - buf_pos: 0, - buf_end: 0, - dirent_buf, - }) - } -} - -impl<'a> Iterator for ReadDir<'a> { - type Item = Result>; - - fn next(&mut self) -> Option>> { - if self.end_of_stream { - return None; - } - - loop { - if self.buf_pos >= self.buf_end { - match self.inner.read_dir(&mut self.dirent_buf) { - Ok(n) => { - if n == 0 { - self.end_of_stream = true; - return None; - } - self.buf_pos = 0; - self.buf_end = n; - } - Err(e) => { - self.end_of_stream = true; - return Some(Err(e)); - } - } - } - let entry = &self.dirent_buf[self.buf_pos]; - self.buf_pos += 1; - let name_bytes = entry.name_as_bytes(); - if name_bytes == b"." || name_bytes == b".." { - continue; - } - let entry_name = unsafe { core::str::from_utf8_unchecked(name_bytes).into() }; - let entry_type = entry.entry_type(); - - return Some(Ok(DirEntry { - dir_path: self.path, - entry_name, - entry_type, - })); - } - } -} - -impl DirEntry<'_> { - /// Returns the full path to the file that this entry represents. - /// - /// The full path is created by joining the original path to `read_dir` - /// with the filename of this entry. - pub fn path(&self) -> String { - String::from(self.dir_path.trim_end_matches('/')) + "/" + &self.entry_name - } - - /// Returns the bare file name of this directory entry without any other - /// leading path component. - pub fn file_name(&self) -> String { - self.entry_name.clone() - } - - /// Returns the file type for the file that this entry points at. - pub fn file_type(&self) -> FileType { - self.entry_type - } -} - -impl fmt::Debug for DirEntry<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("DirEntry").field(&self.path()).finish() - } -} - -impl DirBuilder { - /// Creates a new set of options with default mode/security settings for all - /// platforms and also non-recursive. - pub fn new() -> Self { - Self { recursive: false } - } - - /// Indicates that directories should be created recursively, creating all - /// parent directories. Parents that do not exist are created with the same - /// security and permissions settings. - pub fn recursive(&mut self, recursive: bool) -> &mut Self { - self.recursive = recursive; - self - } - - /// Creates the specified directory with the options configured in this - /// builder. - pub fn create(&self, path: &str) -> Result<()> { - if self.recursive { - self.create_dir_all(path) - } else { - crate::root::create_dir(None, path) - } - } - - fn create_dir_all(&self, _path: &str) -> Result<()> { - axerrno::ax_err!( - Unsupported, - "Recursive directory creation is not supported yet" - ) - } -} diff --git a/modules/axfs/src/api/file.rs b/modules/axfs/src/api/file.rs deleted file mode 100644 index eec1698548..0000000000 --- a/modules/axfs/src/api/file.rs +++ /dev/null @@ -1,193 +0,0 @@ -use axio::{Result, SeekFrom, prelude::*}; -use core::fmt; - -use crate::fops; - -/// A structure representing a type of file with accessors for each file type. -/// It is returned by [`Metadata::file_type`] method. -pub type FileType = fops::FileType; - -/// Representation of the various permissions on a file. -pub type Permissions = fops::FilePerm; - -/// An object providing access to an open file on the filesystem. -pub struct File { - inner: fops::File, -} - -/// Metadata information about a file. -pub struct Metadata(fops::FileAttr); - -/// Options and flags which can be used to configure how a file is opened. -#[derive(Clone, Debug)] -pub struct OpenOptions(fops::OpenOptions); - -impl Default for OpenOptions { - fn default() -> Self { - Self::new() - } -} - -impl OpenOptions { - /// Creates a blank new set of options ready for configuration. - pub const fn new() -> Self { - OpenOptions(fops::OpenOptions::new()) - } - - /// Sets the option for read access. - pub fn read(&mut self, read: bool) -> &mut Self { - self.0.read(read); - self - } - - /// Sets the option for write access. - pub fn write(&mut self, write: bool) -> &mut Self { - self.0.write(write); - self - } - - /// Sets the option for the append mode. - pub fn append(&mut self, append: bool) -> &mut Self { - self.0.append(append); - self - } - - /// Sets the option for truncating a previous file. - pub fn truncate(&mut self, truncate: bool) -> &mut Self { - self.0.truncate(truncate); - self - } - - /// Sets the option to create a new file, or open it if it already exists. - pub fn create(&mut self, create: bool) -> &mut Self { - self.0.create(create); - self - } - - /// Sets the option to create a new file, failing if it already exists. - pub fn create_new(&mut self, create_new: bool) -> &mut Self { - self.0.create_new(create_new); - self - } - - /// Opens a file at `path` with the options specified by `self`. - pub fn open(&self, path: &str) -> Result { - fops::File::open(path, &self.0).map(|inner| File { inner }) - } -} - -impl Metadata { - /// Returns the file type for this metadata. - pub const fn file_type(&self) -> FileType { - self.0.file_type() - } - - /// Returns `true` if this metadata is for a directory. The - /// result is mutually exclusive to the result of - /// [`Metadata::is_file`]. - pub const fn is_dir(&self) -> bool { - self.0.is_dir() - } - - /// Returns `true` if this metadata is for a regular file. The - /// result is mutually exclusive to the result of - /// [`Metadata::is_dir`]. - pub const fn is_file(&self) -> bool { - self.0.is_file() - } - - /// Returns the size of the file, in bytes, this metadata is for. - #[allow(clippy::len_without_is_empty)] - pub const fn len(&self) -> u64 { - self.0.size() - } - - /// Returns the permissions of the file this metadata is for. - pub const fn permissions(&self) -> Permissions { - self.0.perm() - } - - /// Returns the total size of this file in bytes. - pub const fn size(&self) -> u64 { - self.0.size() - } - - /// Returns the number of blocks allocated to the file, in 512-byte units. - pub const fn blocks(&self) -> u64 { - self.0.blocks() - } -} - -impl fmt::Debug for Metadata { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Metadata") - .field("file_type", &self.file_type()) - .field("is_dir", &self.is_dir()) - .field("is_file", &self.is_file()) - .field("permissions", &self.permissions()) - .finish_non_exhaustive() - } -} - -impl File { - /// Attempts to open a file in read-only mode. - pub fn open(path: &str) -> Result { - OpenOptions::new().read(true).open(path) - } - - /// Opens a file in write-only mode. - pub fn create(path: &str) -> Result { - OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(path) - } - - /// Creates a new file in read-write mode; error if the file exists. - pub fn create_new(path: &str) -> Result { - OpenOptions::new() - .read(true) - .write(true) - .create_new(true) - .open(path) - } - - /// Returns a new OpenOptions object. - pub fn options() -> OpenOptions { - OpenOptions::new() - } - - /// Truncates or extends the underlying file, updating the size of - /// this file to become `size`. - pub fn set_len(&self, size: u64) -> Result<()> { - self.inner.truncate(size) - } - - /// Queries metadata about the underlying file. - pub fn metadata(&self) -> Result { - self.inner.get_attr().map(Metadata) - } -} - -impl Read for File { - fn read(&mut self, buf: &mut [u8]) -> Result { - self.inner.read(buf) - } -} - -impl Write for File { - fn write(&mut self, buf: &[u8]) -> Result { - self.inner.write(buf) - } - - fn flush(&mut self) -> Result<()> { - self.inner.flush() - } -} - -impl Seek for File { - fn seek(&mut self, pos: SeekFrom) -> Result { - self.inner.seek(pos) - } -} diff --git a/modules/axfs/src/api/mod.rs b/modules/axfs/src/api/mod.rs deleted file mode 100644 index 7773c251bb..0000000000 --- a/modules/axfs/src/api/mod.rs +++ /dev/null @@ -1,89 +0,0 @@ -//! [`std::fs`]-like high-level filesystem manipulation operations. - -mod dir; -mod file; - -pub use self::dir::{DirBuilder, DirEntry, ReadDir}; -pub use self::file::{File, FileType, Metadata, OpenOptions, Permissions}; - -use alloc::{string::String, vec::Vec}; -use axio::{self as io, prelude::*}; - -/// Returns an iterator over the entries within a directory. -pub fn read_dir(path: &str) -> io::Result> { - ReadDir::new(path) -} - -/// Returns the canonical, absolute form of a path with all intermediate -/// components normalized. -pub fn canonicalize(path: &str) -> io::Result { - crate::root::absolute_path(path) -} - -/// Returns the current working directory as a [`String`]. -pub fn current_dir() -> io::Result { - crate::root::current_dir() -} - -/// Changes the current working directory to the specified path. -pub fn set_current_dir(path: &str) -> io::Result<()> { - crate::root::set_current_dir(path) -} - -/// Read the entire contents of a file into a bytes vector. -pub fn read(path: &str) -> io::Result> { - let mut file = File::open(path)?; - let size = file.metadata().map(|m| m.len()).unwrap_or(0); - let mut bytes = Vec::with_capacity(size as usize); - file.read_to_end(&mut bytes)?; - Ok(bytes) -} - -/// Read the entire contents of a file into a string. -pub fn read_to_string(path: &str) -> io::Result { - let mut file = File::open(path)?; - let size = file.metadata().map(|m| m.len()).unwrap_or(0); - let mut string = String::with_capacity(size as usize); - file.read_to_string(&mut string)?; - Ok(string) -} - -/// Write a slice as the entire contents of a file. -pub fn write>(path: &str, contents: C) -> io::Result<()> { - File::create(path)?.write_all(contents.as_ref()) -} - -/// Given a path, query the file system to get information about a file, -/// directory, etc. -pub fn metadata(path: &str) -> io::Result { - File::open(path)?.metadata() -} - -/// Creates a new, empty directory at the provided path. -pub fn create_dir(path: &str) -> io::Result<()> { - DirBuilder::new().create(path) -} - -/// Recursively create a directory and all of its parent components if they -/// are missing. -pub fn create_dir_all(path: &str) -> io::Result<()> { - DirBuilder::new().recursive(true).create(path) -} - -/// Removes an empty directory. -pub fn remove_dir(path: &str) -> io::Result<()> { - crate::root::remove_dir(None, path) -} - -/// Removes a file from the filesystem. -pub fn remove_file(path: &str) -> io::Result<()> { - crate::root::remove_file(None, path) -} - -/// Rename a file or directory to a new name. -/// Delete the original file if `old` already exists. -/// -/// This only works then the new path is in the same mounted fs. -pub fn rename(old: &str, new: &str) -> io::Result<()> { - crate::root::rename(old, new) -} diff --git a/modules/axfs/src/dev.rs b/modules/axfs/src/dev.rs deleted file mode 100644 index 47cc2d2f35..0000000000 --- a/modules/axfs/src/dev.rs +++ /dev/null @@ -1,92 +0,0 @@ -use axdriver::prelude::*; - -const BLOCK_SIZE: usize = 512; - -/// A disk device with a cursor. -pub struct Disk { - block_id: u64, - offset: usize, - dev: AxBlockDevice, -} - -impl Disk { - /// Create a new disk. - pub fn new(dev: AxBlockDevice) -> Self { - assert_eq!(BLOCK_SIZE, dev.block_size()); - Self { - block_id: 0, - offset: 0, - dev, - } - } - - /// Get the size of the disk. - pub fn size(&self) -> u64 { - self.dev.num_blocks() * BLOCK_SIZE as u64 - } - - /// Get the position of the cursor. - pub fn position(&self) -> u64 { - self.block_id * BLOCK_SIZE as u64 + self.offset as u64 - } - - /// Set the position of the cursor. - pub fn set_position(&mut self, pos: u64) { - self.block_id = pos / BLOCK_SIZE as u64; - self.offset = pos as usize % BLOCK_SIZE; - } - - /// Read within one block, returns the number of bytes read. - pub fn read_one(&mut self, buf: &mut [u8]) -> DevResult { - let read_size = if self.offset == 0 && buf.len() >= BLOCK_SIZE { - // whole block - self.dev - .read_block(self.block_id, &mut buf[0..BLOCK_SIZE])?; - self.block_id += 1; - BLOCK_SIZE - } else { - // partial block - let mut data = [0u8; BLOCK_SIZE]; - let start = self.offset; - let count = buf.len().min(BLOCK_SIZE - self.offset); - - self.dev.read_block(self.block_id, &mut data)?; - buf[..count].copy_from_slice(&data[start..start + count]); - - self.offset += count; - if self.offset >= BLOCK_SIZE { - self.block_id += 1; - self.offset -= BLOCK_SIZE; - } - count - }; - Ok(read_size) - } - - /// Write within one block, returns the number of bytes written. - pub fn write_one(&mut self, buf: &[u8]) -> DevResult { - let write_size = if self.offset == 0 && buf.len() >= BLOCK_SIZE { - // whole block - self.dev.write_block(self.block_id, &buf[0..BLOCK_SIZE])?; - self.block_id += 1; - BLOCK_SIZE - } else { - // partial block - let mut data = [0u8; BLOCK_SIZE]; - let start = self.offset; - let count = buf.len().min(BLOCK_SIZE - self.offset); - - self.dev.read_block(self.block_id, &mut data)?; - data[start..start + count].copy_from_slice(&buf[..count]); - self.dev.write_block(self.block_id, &data)?; - - self.offset += count; - if self.offset >= BLOCK_SIZE { - self.block_id += 1; - self.offset -= BLOCK_SIZE; - } - count - }; - Ok(write_size) - } -} diff --git a/modules/axfs/src/disk.rs b/modules/axfs/src/disk.rs new file mode 100644 index 0000000000..308181617e --- /dev/null +++ b/modules/axfs/src/disk.rs @@ -0,0 +1,165 @@ +use alloc::{boxed::Box, vec}; +use core::mem; + +use axdriver::prelude::*; + +fn take<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] { + let (first, rem) = buf.split_at(cnt); + *buf = rem; + first +} + +fn take_mut<'a>(buf: &mut &'a mut [u8], cnt: usize) -> &'a mut [u8] { + // use mem::take to circumvent lifetime issues + let (first, rem) = mem::take(buf).split_at_mut(cnt); + *buf = rem; + first +} + +/// A disk device with a cursor. +pub struct SeekableDisk { + dev: AxBlockDevice, + + block_id: u64, + offset: usize, + block_size_log2: u8, + + read_buffer: Box<[u8]>, + write_buffer: Box<[u8]>, + /// Whether we have unsaved changes in the write buffer. + /// + /// It's guaranteed that when `offset == 0`, write_buffer_dirty is false. + write_buffer_dirty: bool, +} + +impl SeekableDisk { + /// Create a new disk. + pub fn new(dev: AxBlockDevice) -> Self { + assert!(dev.block_size().is_power_of_two()); + let block_size_log2 = dev.block_size().trailing_zeros() as u8; + let read_buffer = vec![0u8; dev.block_size()].into_boxed_slice(); + let write_buffer = vec![0u8; dev.block_size()].into_boxed_slice(); + Self { + dev, + block_id: 0, + offset: 0, + block_size_log2, + read_buffer, + write_buffer, + write_buffer_dirty: false, + } + } + + /// Get the size of the disk. + pub fn size(&self) -> u64 { + self.dev.num_blocks() << self.block_size_log2 + } + + /// Get the block size. + pub fn block_size(&self) -> usize { + 1 << self.block_size_log2 + } + + /// Get the position of the cursor. + pub fn position(&self) -> u64 { + (self.block_id << self.block_size_log2) + self.offset as u64 + } + + /// Set the position of the cursor. + pub fn set_position(&mut self, pos: u64) -> DevResult<()> { + self.flush()?; + self.block_id = pos >> self.block_size_log2; + self.offset = pos as usize & (self.block_size() - 1); + Ok(()) + } + + /// Write all pending changes to the disk. + pub fn flush(&mut self) -> DevResult<()> { + if self.write_buffer_dirty { + self.dev.write_block(self.block_id, &self.write_buffer)?; + self.write_buffer_dirty = false; + } + Ok(()) + } + + fn read_partial(&mut self, buf: &mut &mut [u8]) -> DevResult { + self.flush()?; + self.dev.read_block(self.block_id, &mut self.read_buffer)?; + + let data = &self.read_buffer[self.offset..]; + let length = buf.len().min(data.len()); + take_mut(buf, length).copy_from_slice(&data[..length]); + + self.offset += length; + if self.offset == self.block_size() { + self.block_id += 1; + self.offset = 0; + } + + Ok(length) + } + + /// Read from the disk, returns the number of bytes read. + pub fn read(&mut self, mut buf: &mut [u8]) -> DevResult { + let mut read = 0; + if self.offset != 0 { + read += self.read_partial(&mut buf)?; + } + if buf.len() >= self.block_size() { + let blocks = buf.len() >> self.block_size_log2; + let length = blocks << self.block_size_log2; + self.dev + .read_block(self.block_id, take_mut(&mut buf, length))?; + read += length; + + self.block_id += blocks as u64; + } + if !buf.is_empty() { + read += self.read_partial(&mut buf)?; + } + + Ok(read) + } + + fn write_partial(&mut self, buf: &mut &[u8]) -> DevResult { + if !self.write_buffer_dirty { + self.dev.read_block(self.block_id, &mut self.write_buffer)?; + self.write_buffer_dirty = true; + } + + let data = &mut self.write_buffer[self.offset..]; + let length = buf.len().min(data.len()); + data[..length].copy_from_slice(take(buf, length)); + + self.offset += length; + if self.offset == self.block_size() { + self.flush()?; + self.block_id += 1; + self.offset = 0; + } + + Ok(length) + } + + /// Write to the disk, returns the number of bytes written. + pub fn write(&mut self, mut buf: &[u8]) -> DevResult { + let mut written = 0; + if self.offset != 0 { + written += self.write_partial(&mut buf)?; + } + if buf.len() >= self.block_size() { + let blocks = buf.len() >> self.block_size_log2; + let length = blocks << self.block_size_log2; + self.dev + .write_block(self.block_id, take(&mut buf, length))?; + written += length; + + self.block_id += blocks as u64; + } + if !buf.is_empty() { + written += self.write_partial(&mut buf)?; + } + + Ok(written) + } +} diff --git a/modules/axfs/src/fops.rs b/modules/axfs/src/fops.rs deleted file mode 100644 index b2a0632450..0000000000 --- a/modules/axfs/src/fops.rs +++ /dev/null @@ -1,418 +0,0 @@ -//! Low-level filesystem operations. - -use axerrno::{AxError, AxResult, ax_err, ax_err_type}; -use axfs_vfs::{VfsError, VfsNodeRef}; -use axio::SeekFrom; -use cap_access::{Cap, WithCap}; -use core::fmt; - -#[cfg(feature = "myfs")] -pub use crate::dev::Disk; -#[cfg(feature = "myfs")] -pub use crate::fs::myfs::MyFileSystemIf; - -/// Alias of [`axfs_vfs::VfsNodeType`]. -pub type FileType = axfs_vfs::VfsNodeType; -/// Alias of [`axfs_vfs::VfsDirEntry`]. -pub type DirEntry = axfs_vfs::VfsDirEntry; -/// Alias of [`axfs_vfs::VfsNodeAttr`]. -pub type FileAttr = axfs_vfs::VfsNodeAttr; -/// Alias of [`axfs_vfs::VfsNodePerm`]. -pub type FilePerm = axfs_vfs::VfsNodePerm; - -/// An opened file object, with open permissions and a cursor. -pub struct File { - node: WithCap, - is_append: bool, - offset: u64, -} - -/// An opened directory object, with open permissions and a cursor for -/// [`read_dir`](Directory::read_dir). -pub struct Directory { - node: WithCap, - entry_idx: usize, -} - -/// Options and flags which can be used to configure how a file is opened. -#[derive(Clone)] -pub struct OpenOptions { - // generic - read: bool, - write: bool, - append: bool, - truncate: bool, - create: bool, - create_new: bool, - // system-specific - _custom_flags: i32, - _mode: u32, -} - -impl Default for OpenOptions { - fn default() -> Self { - Self::new() - } -} - -impl OpenOptions { - /// Creates a blank new set of options ready for configuration. - pub const fn new() -> Self { - Self { - // generic - read: false, - write: false, - append: false, - truncate: false, - create: false, - create_new: false, - // system-specific - _custom_flags: 0, - _mode: 0o666, - } - } - /// Sets the option for read access. - pub fn read(&mut self, read: bool) { - self.read = read; - } - /// Sets the option for write access. - pub fn write(&mut self, write: bool) { - self.write = write; - } - /// Sets the option for the append mode. - pub fn append(&mut self, append: bool) { - self.append = append; - } - /// Sets the option for truncating a previous file. - pub fn truncate(&mut self, truncate: bool) { - self.truncate = truncate; - } - /// Sets the option to create a new file, or open it if it already exists. - pub fn create(&mut self, create: bool) { - self.create = create; - } - /// Sets the option to create a new file, failing if it already exists. - pub fn create_new(&mut self, create_new: bool) { - self.create_new = create_new; - } - - const fn is_valid(&self) -> bool { - if !self.read && !self.write && !self.append { - return false; - } - match (self.write, self.append) { - (true, false) => {} - (false, false) => { - if self.truncate || self.create || self.create_new { - return false; - } - } - (_, true) => { - if self.truncate && !self.create_new { - return false; - } - } - } - true - } -} - -impl File { - fn access_node(&self, cap: Cap) -> AxResult<&VfsNodeRef> { - self.node.access_or_err(cap, AxError::PermissionDenied) - } - - fn _open_at(dir: Option<&VfsNodeRef>, path: &str, opts: &OpenOptions) -> AxResult { - debug!("open file: {path} {opts:?}"); - if !opts.is_valid() { - return ax_err!(InvalidInput); - } - - let node_option = crate::root::lookup(dir, path); - let node = if opts.create || opts.create_new { - match node_option { - Ok(node) => { - // already exists - if opts.create_new { - return ax_err!(AlreadyExists); - } - node - } - // not exists, create new - Err(VfsError::NotFound) => crate::root::create_file(dir, path)?, - Err(e) => return Err(e), - } - } else { - // just open the existing - node_option? - }; - - let attr = node.get_attr()?; - if attr.is_dir() - && (opts.create || opts.create_new || opts.write || opts.append || opts.truncate) - { - return ax_err!(IsADirectory); - } - let access_cap = opts.into(); - if !perm_to_cap(attr.perm()).contains(access_cap) { - return ax_err!(PermissionDenied); - } - - node.open()?; - if opts.truncate { - node.truncate(0)?; - } - Ok(Self { - node: WithCap::new(node, access_cap), - is_append: opts.append, - offset: 0, - }) - } - - /// Opens a file at the path relative to the current directory. Returns a - /// [`File`] object. - pub fn open(path: &str, opts: &OpenOptions) -> AxResult { - Self::_open_at(None, path, opts) - } - - /// Truncates the file to the specified size. - pub fn truncate(&self, size: u64) -> AxResult { - self.access_node(Cap::WRITE)?.truncate(size)?; - Ok(()) - } - - /// Reads the file at the current position. Returns the number of bytes - /// read. - /// - /// After the read, the cursor will be advanced by the number of bytes read. - pub fn read(&mut self, buf: &mut [u8]) -> AxResult { - let node = self.access_node(Cap::READ)?; - let read_len = node.read_at(self.offset, buf)?; - self.offset += read_len as u64; - Ok(read_len) - } - - /// Reads the file at the given position. Returns the number of bytes read. - /// - /// It does not update the file cursor. - pub fn read_at(&self, offset: u64, buf: &mut [u8]) -> AxResult { - let node = self.access_node(Cap::READ)?; - let read_len = node.read_at(offset, buf)?; - Ok(read_len) - } - - /// Writes the file at the current position. Returns the number of bytes - /// written. - /// - /// After the write, the cursor will be advanced by the number of bytes - /// written. - pub fn write(&mut self, buf: &[u8]) -> AxResult { - let offset = if self.is_append { - self.get_attr()?.size() - } else { - self.offset - }; - let node = self.access_node(Cap::WRITE)?; - let write_len = node.write_at(offset, buf)?; - self.offset = offset + write_len as u64; - Ok(write_len) - } - - /// Writes the file at the given position. Returns the number of bytes - /// written. - /// - /// It does not update the file cursor. - pub fn write_at(&self, offset: u64, buf: &[u8]) -> AxResult { - let node = self.access_node(Cap::WRITE)?; - let write_len = node.write_at(offset, buf)?; - Ok(write_len) - } - - /// Flushes the file, writes all buffered data to the underlying device. - pub fn flush(&self) -> AxResult { - self.access_node(Cap::WRITE)?.fsync()?; - Ok(()) - } - - /// Sets the cursor of the file to the specified offset. Returns the new - /// position after the seek. - pub fn seek(&mut self, pos: SeekFrom) -> AxResult { - let size = self.get_attr()?.size(); - let new_offset = match pos { - SeekFrom::Start(pos) => Some(pos), - SeekFrom::Current(off) => self.offset.checked_add_signed(off), - SeekFrom::End(off) => size.checked_add_signed(off), - } - .ok_or_else(|| ax_err_type!(InvalidInput))?; - self.offset = new_offset; - Ok(new_offset) - } - - /// Gets the file attributes. - pub fn get_attr(&self) -> AxResult { - self.access_node(Cap::empty())?.get_attr() - } -} - -impl Directory { - fn access_node(&self, cap: Cap) -> AxResult<&VfsNodeRef> { - self.node.access_or_err(cap, AxError::PermissionDenied) - } - - fn _open_dir_at(dir: Option<&VfsNodeRef>, path: &str, opts: &OpenOptions) -> AxResult { - debug!("open dir: {path}"); - if !opts.read { - return ax_err!(InvalidInput); - } - if opts.create || opts.create_new || opts.write || opts.append || opts.truncate { - return ax_err!(InvalidInput); - } - - let node = crate::root::lookup(dir, path)?; - let attr = node.get_attr()?; - if !attr.is_dir() { - return ax_err!(NotADirectory); - } - let access_cap = opts.into(); - if !perm_to_cap(attr.perm()).contains(access_cap) { - return ax_err!(PermissionDenied); - } - - node.open()?; - Ok(Self { - node: WithCap::new(node, access_cap), - entry_idx: 0, - }) - } - - fn access_at(&self, path: &str) -> AxResult> { - if path.starts_with('/') { - Ok(None) - } else { - Ok(Some(self.access_node(Cap::EXECUTE)?)) - } - } - - /// Opens a directory at the path relative to the current directory. - /// Returns a [`Directory`] object. - pub fn open_dir(path: &str, opts: &OpenOptions) -> AxResult { - Self::_open_dir_at(None, path, opts) - } - - /// Opens a directory at the path relative to this directory. Returns a - /// [`Directory`] object. - pub fn open_dir_at(&self, path: &str, opts: &OpenOptions) -> AxResult { - Self::_open_dir_at(self.access_at(path)?, path, opts) - } - - /// Opens a file at the path relative to this directory. Returns a [`File`] - /// object. - pub fn open_file_at(&self, path: &str, opts: &OpenOptions) -> AxResult { - File::_open_at(self.access_at(path)?, path, opts) - } - - /// Creates an empty file at the path relative to this directory. - pub fn create_file(&self, path: &str) -> AxResult { - crate::root::create_file(self.access_at(path)?, path) - } - - /// Creates an empty directory at the path relative to this directory. - pub fn create_dir(&self, path: &str) -> AxResult { - crate::root::create_dir(self.access_at(path)?, path) - } - - /// Removes a file at the path relative to this directory. - pub fn remove_file(&self, path: &str) -> AxResult { - crate::root::remove_file(self.access_at(path)?, path) - } - - /// Removes a directory at the path relative to this directory. - pub fn remove_dir(&self, path: &str) -> AxResult { - crate::root::remove_dir(self.access_at(path)?, path) - } - - /// Reads directory entries starts from the current position into the - /// given buffer. Returns the number of entries read. - /// - /// After the read, the cursor will be advanced by the number of entries - /// read. - pub fn read_dir(&mut self, dirents: &mut [DirEntry]) -> AxResult { - let n = self - .access_node(Cap::READ)? - .read_dir(self.entry_idx, dirents)?; - self.entry_idx += n; - Ok(n) - } - - /// Rename a file or directory to a new name. - /// Delete the original file if `old` already exists. - /// - /// This only works then the new path is in the same mounted fs. - pub fn rename(&self, old: &str, new: &str) -> AxResult { - crate::root::rename(old, new) - } -} - -impl Drop for File { - fn drop(&mut self) { - unsafe { self.node.access_unchecked().release().ok() }; - } -} - -impl Drop for Directory { - fn drop(&mut self) { - unsafe { self.node.access_unchecked().release().ok() }; - } -} - -impl fmt::Debug for OpenOptions { - #[allow(unused_assignments)] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut written = false; - macro_rules! fmt_opt { - ($field: ident, $label: literal) => { - if self.$field { - if written { - write!(f, " | ")?; - } - write!(f, $label)?; - written = true; - } - }; - } - fmt_opt!(read, "READ"); - fmt_opt!(write, "WRITE"); - fmt_opt!(append, "APPEND"); - fmt_opt!(truncate, "TRUNC"); - fmt_opt!(create, "CREATE"); - fmt_opt!(create_new, "CREATE_NEW"); - Ok(()) - } -} - -impl From<&OpenOptions> for Cap { - fn from(opts: &OpenOptions) -> Cap { - let mut cap = Cap::empty(); - if opts.read { - cap |= Cap::READ; - } - if opts.write | opts.append { - cap |= Cap::WRITE; - } - cap - } -} - -fn perm_to_cap(perm: FilePerm) -> Cap { - let mut cap = Cap::empty(); - if perm.owner_readable() { - cap |= Cap::READ; - } - if perm.owner_writable() { - cap |= Cap::WRITE; - } - if perm.owner_executable() { - cap |= Cap::EXECUTE; - } - cap -} diff --git a/modules/axfs/src/fs/ext4/fs.rs b/modules/axfs/src/fs/ext4/fs.rs new file mode 100644 index 0000000000..447b9ffcd7 --- /dev/null +++ b/modules/axfs/src/fs/ext4/fs.rs @@ -0,0 +1,79 @@ +use alloc::sync::Arc; +use core::cell::OnceCell; + +use axdriver::AxBlockDevice; +use axfs_ng_vfs::{ + DirEntry, DirNode, Filesystem, FilesystemOps, Reference, StatFs, VfsResult, path::MAX_NAME_LEN, +}; +use kspin::{SpinNoPreempt as Mutex, SpinNoPreemptGuard as MutexGuard}; +use lwext4_rust::{FsConfig, ffi::EXT4_ROOT_INO}; + +use super::{ + Ext4Disk, Inode, + util::{LwExt4Filesystem, into_vfs_err}, +}; + +const EXT4_CONFIG: FsConfig = FsConfig { bcache_size: 256 }; + +pub struct Ext4Filesystem { + inner: Mutex, + root_dir: OnceCell, +} + +impl Ext4Filesystem { + pub fn new(dev: AxBlockDevice) -> VfsResult { + let ext4 = + lwext4_rust::Ext4Filesystem::new(Ext4Disk(dev), EXT4_CONFIG).map_err(into_vfs_err)?; + + let fs = Arc::new(Self { + inner: Mutex::new(ext4), + root_dir: OnceCell::new(), + }); + let _ = fs.root_dir.set(DirEntry::new_dir( + |this| DirNode::new(Inode::new(fs.clone(), EXT4_ROOT_INO, Some(this))), + Reference::root(), + )); + Ok(Filesystem::new(fs)) + } + + pub(crate) fn lock(&self) -> MutexGuard { + self.inner.lock() + } +} + +unsafe impl Send for Ext4Filesystem {} + +unsafe impl Sync for Ext4Filesystem {} + +impl FilesystemOps for Ext4Filesystem { + fn name(&self) -> &str { + "ext4" + } + + fn root_dir(&self) -> DirEntry { + self.root_dir.get().unwrap().clone() + } + + fn stat(&self) -> VfsResult { + let mut fs = self.lock(); + let stat = fs.stat().map_err(into_vfs_err)?; + Ok(StatFs { + fs_type: 0xef53, + block_size: stat.block_size as _, + blocks: stat.blocks_count, + blocks_free: stat.free_blocks_count, + blocks_available: stat.free_blocks_count, + + file_count: stat.inodes_count as _, + free_file_count: stat.free_inodes_count as _, + + name_length: MAX_NAME_LEN as _, + fragment_size: 0, + mount_flags: 0, + }) + } + + fn flush(&self) -> VfsResult<()> { + self.inner.lock().flush().map_err(into_vfs_err) + } +} diff --git a/modules/axfs/src/fs/ext4/inode.rs b/modules/axfs/src/fs/ext4/inode.rs new file mode 100644 index 0000000000..68b24423a2 --- /dev/null +++ b/modules/axfs/src/fs/ext4/inode.rs @@ -0,0 +1,269 @@ +use alloc::{borrow::ToOwned, string::String, sync::Arc}; +use core::{any::Any, task::Context}; + +use axfs_ng_vfs::{ + DeviceId, DirEntry, DirEntrySink, DirNode, DirNodeOps, FileNode, FileNodeOps, FilesystemOps, + Metadata, MetadataUpdate, NodeFlags, NodeOps, NodePermission, NodeType, Reference, VfsError, + VfsResult, WeakDirEntry, +}; +use axpoll::{IoEvents, Pollable}; +use lwext4_rust::{FileAttr, InodeType}; + +use super::{ + Ext4Filesystem, + util::{LwExt4Filesystem, into_vfs_err, into_vfs_type}, +}; + +pub struct Inode { + fs: Arc, + ino: u32, + this: Option, +} + +impl Inode { + pub(crate) fn new(fs: Arc, ino: u32, this: Option) -> Arc { + Arc::new(Self { fs, ino, this }) + } + + fn create_entry(&self, entry: &lwext4_rust::DirEntry, name: impl Into) -> DirEntry { + let reference = Reference::new( + self.this.as_ref().and_then(WeakDirEntry::upgrade), + name.into(), + ); + if entry.inode_type() == InodeType::Directory { + DirEntry::new_dir( + |this| DirNode::new(Inode::new(self.fs.clone(), entry.ino(), Some(this))), + reference, + ) + } else { + DirEntry::new_file( + FileNode::new(Inode::new(self.fs.clone(), entry.ino(), None)), + into_vfs_type(entry.inode_type()), + reference, + ) + } + } + + fn lookup_locked(&self, fs: &mut LwExt4Filesystem, name: &str) -> VfsResult { + let mut result = fs.lookup(self.ino, name).map_err(into_vfs_err)?; + let entry = result.entry(); + Ok(self.create_entry(&entry, name)) + } + + fn update_ctime_locked(&self, fs: &mut LwExt4Filesystem, ino: u32) -> VfsResult<()> { + fs.with_inode_ref(ino, |ino| { + ino.update_ctime(); + Ok(()) + }) + .map_err(into_vfs_err) + } +} + +impl NodeOps for Inode { + fn inode(&self) -> u64 { + self.ino as _ + } + + fn metadata(&self) -> VfsResult { + let mut attr = FileAttr::default(); + self.fs + .lock() + .get_attr(self.ino, &mut attr) + .map_err(into_vfs_err)?; + Ok(Metadata { + inode: self.ino as _, + device: attr.device, + nlink: attr.nlink, + mode: NodePermission::from_bits_truncate(attr.mode as u16), + node_type: into_vfs_type(attr.node_type), + uid: attr.uid, + gid: attr.gid, + size: attr.size, + block_size: attr.block_size, + blocks: attr.blocks, + rdev: DeviceId::default(), + atime: attr.atime, + mtime: attr.mtime, + ctime: attr.ctime, + }) + } + + fn update_metadata(&self, update: MetadataUpdate) -> VfsResult<()> { + let mut fs = self.fs.lock(); + fs.with_inode_ref(self.ino, |inode| { + if let Some(mode) = update.mode { + inode.set_mode((inode.mode() & !0xfff) | (mode.bits() as u32)); + } + if let Some((uid, gid)) = update.owner { + inode.set_owner(uid as _, gid as _); + } + if let Some(atime) = update.atime { + inode.set_atime(&atime); + } + if let Some(mtime) = update.mtime { + inode.set_mtime(&mtime); + } + inode.update_ctime(); + Ok(()) + }) + .map_err(into_vfs_err)?; + Ok(()) + } + + fn len(&self) -> VfsResult { + self.fs + .lock() + .with_inode_ref(self.ino, |inode| Ok(inode.size())) + .map_err(into_vfs_err) + } + + fn filesystem(&self) -> &dyn FilesystemOps { + &*self.fs + } + + fn sync(&self, _data_only: bool) -> VfsResult<()> { + Ok(()) + } + + fn into_any(self: Arc) -> Arc { + self + } + + fn flags(&self) -> NodeFlags { + NodeFlags::BLOCKING + } +} + +impl FileNodeOps for Inode { + fn read_at(&self, buf: &mut [u8], offset: u64) -> VfsResult { + self.fs + .lock() + .read_at(self.ino, buf, offset) + .map_err(into_vfs_err) + } + + fn write_at(&self, buf: &[u8], offset: u64) -> VfsResult { + self.fs + .lock() + .write_at(self.ino, buf, offset) + .map_err(into_vfs_err) + } + + fn append(&self, buf: &[u8]) -> VfsResult<(usize, u64)> { + let mut fs = self.fs.lock(); + let length = fs + .with_inode_ref(self.ino, |inode| Ok(inode.size())) + .map_err(into_vfs_err)?; + let written = fs.write_at(self.ino, buf, length).map_err(into_vfs_err)?; + Ok((written, length + written as u64)) + } + + fn set_len(&self, len: u64) -> VfsResult<()> { + self.fs.lock().set_len(self.ino, len).map_err(into_vfs_err) + } + + fn set_symlink(&self, target: &str) -> VfsResult<()> { + self.fs + .lock() + .set_symlink(self.ino, target.as_bytes()) + .map_err(into_vfs_err) + } +} + +impl Pollable for Inode { + fn poll(&self) -> IoEvents { + IoEvents::IN | IoEvents::OUT + } + + fn register(&self, _context: &mut Context<'_>, _events: IoEvents) {} +} + +impl DirNodeOps for Inode { + fn read_dir(&self, offset: u64, sink: &mut dyn DirEntrySink) -> VfsResult { + let mut fs = self.fs.lock(); + let mut reader = fs.read_dir(self.ino, offset).map_err(into_vfs_err)?; + let mut count = 0; + while let Some(entry) = reader.current() { + let name = core::str::from_utf8(entry.name()) + .map_err(|_| VfsError::InvalidData)? + .to_owned(); + let ino = entry.ino() as u64; + let node_type = into_vfs_type(entry.inode_type()); + reader.step().map_err(into_vfs_err)?; + if !sink.accept(&name, ino, node_type, reader.offset()) { + break; + } + count += 1; + } + Ok(count) + } + + fn lookup(&self, name: &str) -> VfsResult { + let mut fs = self.fs.lock(); + self.lookup_locked(&mut fs, name) + } + + fn create( + &self, + name: &str, + node_type: NodeType, + permission: NodePermission, + ) -> VfsResult { + let inode_type = match node_type { + NodeType::Fifo => InodeType::Fifo, + NodeType::CharacterDevice => InodeType::CharacterDevice, + NodeType::Directory => InodeType::Directory, + NodeType::BlockDevice => InodeType::BlockDevice, + NodeType::RegularFile => InodeType::RegularFile, + NodeType::Symlink => InodeType::Symlink, + NodeType::Socket => InodeType::Socket, + NodeType::Unknown => { + return Err(VfsError::InvalidData); + } + }; + let mut fs = self.fs.lock(); + if fs.lookup(self.ino, name).is_ok() { + return Err(VfsError::AlreadyExists); + } + let ino = fs + .create(self.ino, name, inode_type, permission.bits() as _) + .map_err(into_vfs_err)?; + self.update_ctime_locked(&mut fs, ino)?; + + let reference = Reference::new( + self.this.as_ref().and_then(WeakDirEntry::upgrade), + name.to_owned(), + ); + Ok(if node_type == NodeType::Directory { + DirEntry::new_dir( + |this| DirNode::new(Inode::new(self.fs.clone(), ino, Some(this))), + reference, + ) + } else { + DirEntry::new_file( + FileNode::new(Inode::new(self.fs.clone(), ino, None)), + node_type, + reference, + ) + }) + } + + fn link(&self, name: &str, node: &DirEntry) -> VfsResult { + let mut fs = self.fs.lock(); + fs.link(self.ino, name, node.inode() as _) + .map_err(into_vfs_err)?; + self.update_ctime_locked(&mut fs, node.inode() as _)?; + self.lookup_locked(&mut fs, name) + } + + fn unlink(&self, name: &str) -> VfsResult<()> { + self.fs.lock().unlink(self.ino, name).map_err(into_vfs_err) + } + + fn rename(&self, src_name: &str, dst_dir: &DirNode, dst_name: &str) -> VfsResult<()> { + let dst_dir: Arc = dst_dir.downcast().map_err(|_| VfsError::InvalidInput)?; + let mut fs = self.fs.lock(); + fs.rename(self.ino, src_name, dst_dir.ino, dst_name) + .map_err(into_vfs_err) + } +} diff --git a/modules/axfs/src/fs/ext4/mod.rs b/modules/axfs/src/fs/ext4/mod.rs new file mode 100644 index 0000000000..9916e9ca60 --- /dev/null +++ b/modules/axfs/src/fs/ext4/mod.rs @@ -0,0 +1,31 @@ +mod fs; +mod inode; +mod util; + +#[allow(unused_imports)] +use axdriver::{AxBlockDevice, prelude::BlockDriverOps}; +pub use fs::*; +pub use inode::*; +use lwext4_rust::{BlockDevice, Ext4Error, Ext4Result, ffi::EIO}; + +pub(crate) struct Ext4Disk(AxBlockDevice); + +impl BlockDevice for Ext4Disk { + fn read_blocks(&mut self, block_id: u64, buf: &mut [u8]) -> Ext4Result { + self.0 + .read_block(block_id, buf) + .map_err(|_| Ext4Error::new(EIO as _, None))?; + Ok(buf.len()) + } + + fn write_blocks(&mut self, block_id: u64, buf: &[u8]) -> Ext4Result { + self.0 + .write_block(block_id, buf) + .map_err(|_| Ext4Error::new(EIO as _, None))?; + Ok(buf.len()) + } + + fn num_blocks(&self) -> Ext4Result { + Ok(self.0.num_blocks()) + } +} diff --git a/modules/axfs/src/fs/ext4/util.rs b/modules/axfs/src/fs/ext4/util.rs new file mode 100644 index 0000000000..5f9e0f8f41 --- /dev/null +++ b/modules/axfs/src/fs/ext4/util.rs @@ -0,0 +1,36 @@ +use axerrno::LinuxError; +use axfs_ng_vfs::{NodeType, VfsError}; +use lwext4_rust::{Ext4Error, InodeType, SystemHal}; + +use super::Ext4Disk; + +pub struct AxHal; +impl SystemHal for AxHal { + fn now() -> Option { + if cfg!(feature = "times") { + Some(axhal::time::wall_time()) + } else { + None + } + } +} + +pub type LwExt4Filesystem = lwext4_rust::Ext4Filesystem; + +pub fn into_vfs_err(err: Ext4Error) -> VfsError { + let linux_error = LinuxError::try_from(err.code).unwrap_or(LinuxError::EIO); + VfsError::from(linux_error).canonicalize() +} + +pub fn into_vfs_type(ty: InodeType) -> NodeType { + match ty { + InodeType::RegularFile => NodeType::RegularFile, + InodeType::Directory => NodeType::Directory, + InodeType::CharacterDevice => NodeType::CharacterDevice, + InodeType::BlockDevice => NodeType::BlockDevice, + InodeType::Fifo => NodeType::Fifo, + InodeType::Socket => NodeType::Socket, + InodeType::Symlink => NodeType::Symlink, + InodeType::Unknown => NodeType::Unknown, + } +} diff --git a/modules/axfs/src/fs/ext4fs.rs b/modules/axfs/src/fs/ext4fs.rs deleted file mode 100644 index e4f0ee9a26..0000000000 --- a/modules/axfs/src/fs/ext4fs.rs +++ /dev/null @@ -1,381 +0,0 @@ -use alloc::sync::Arc; - -use axdriver_block::DevError; -use axerrno::AxError; -use axfs_vfs::{ - VfsDirEntry, VfsError, VfsNodeAttr, VfsNodeOps, VfsNodePerm, VfsNodeRef, VfsNodeType, VfsOps, - VfsResult, -}; -use axsync::Mutex; -use lwext4_rust::{ - Ext4BlockWrapper, Ext4File, InodeTypes, KernelDevOp, - bindings::{O_CREAT, O_RDONLY, O_RDWR, O_TRUNC, O_WRONLY, SEEK_CUR, SEEK_END, SEEK_SET}, -}; - -use crate::{ - alloc::string::{String, ToString}, - dev::Disk, -}; - -const BLOCK_SIZE: usize = 512; - -#[allow(dead_code)] -pub struct Ext4FileSystem { - inner: Ext4BlockWrapper, - root: VfsNodeRef, -} - -unsafe impl Sync for Ext4FileSystem {} -unsafe impl Send for Ext4FileSystem {} - -impl Ext4FileSystem { - #[cfg(feature = "use-ramdisk")] - pub fn new(mut disk: Disk) -> Self { - unimplemented!() - } - - #[cfg(not(feature = "use-ramdisk"))] - pub fn new(disk: Disk) -> Self { - info!( - "Got Disk size:{}, position:{}", - disk.size(), - disk.position() - ); - let inner = - Ext4BlockWrapper::::new(disk).expect("failed to initialize EXT4 filesystem"); - let root = Arc::new(FileWrapper::new("/", InodeTypes::EXT4_DE_DIR)); - Self { inner, root } - } -} - -/// The [`VfsOps`] trait provides operations on a filesystem. -impl VfsOps for Ext4FileSystem { - fn root_dir(&self) -> VfsNodeRef { - debug!("Get root_dir"); - Arc::clone(&self.root) - } -} - -pub struct FileWrapper(Mutex); - -unsafe impl Send for FileWrapper {} -unsafe impl Sync for FileWrapper {} - -impl FileWrapper { - fn new(path: &str, types: InodeTypes) -> Self { - info!("FileWrapper new {:?} {}", types, path); - Self(Mutex::new(Ext4File::new(path, types))) - } - - fn path_deal_with(&self, path: &str) -> String { - if path.starts_with('/') { - debug!("path_deal_with: {}", path); - } - let trim_path = path.trim_matches('/'); - if trim_path.is_empty() || trim_path == "." { - return String::new(); - } - - if let Some(rest) = trim_path.strip_prefix("./") { - //if starts with "./" - return self.path_deal_with(rest); - } - let rest_p = trim_path.replace("//", "/"); - if trim_path != rest_p { - return self.path_deal_with(&rest_p); - } - let file = self.0.lock(); - let path = file.get_path(); - let fpath = String::from(path.to_str().unwrap().trim_end_matches('/')) + "/" + trim_path; - debug!("dealt with full path: {}", fpath.as_str()); - fpath - } -} - -/// The [`VfsNodeOps`] trait provides operations on a file or a directory. -impl VfsNodeOps for FileWrapper { - fn get_attr(&self) -> VfsResult { - let mut file = self.0.lock(); - let perm = file.file_mode_get().unwrap_or(0o755); - let perm = VfsNodePerm::from_bits_truncate((perm as u16) & 0o777); - let vtype = file.file_type_get(); - let vtype = match vtype { - InodeTypes::EXT4_INODE_MODE_FIFO => VfsNodeType::Fifo, - InodeTypes::EXT4_INODE_MODE_CHARDEV => VfsNodeType::CharDevice, - InodeTypes::EXT4_INODE_MODE_DIRECTORY => VfsNodeType::Dir, - InodeTypes::EXT4_INODE_MODE_BLOCKDEV => VfsNodeType::BlockDevice, - InodeTypes::EXT4_INODE_MODE_FILE => VfsNodeType::File, - InodeTypes::EXT4_INODE_MODE_SOFTLINK => VfsNodeType::SymLink, - InodeTypes::EXT4_INODE_MODE_SOCKET => VfsNodeType::Socket, - _ => { - warn!("unknown file type: {:?}", vtype); - VfsNodeType::File - } - }; - let size = if vtype == VfsNodeType::File { - let path = file.get_path().to_str().unwrap().to_string(); - file.file_open(&path, O_RDONLY) - .map_err(|e| >::try_into(e).unwrap())?; - let fsize = file.file_size(); - file.file_close().expect("failed to close fd"); - fsize - } else { - 0 - }; - let blocks = (size + (BLOCK_SIZE as u64 - 1)) / BLOCK_SIZE as u64; - trace!( - "get_attr of {:?} {:?}, size: {}, blocks: {}", - vtype, - file.get_path(), - size, - blocks - ); - - Ok(VfsNodeAttr::new(perm, vtype, size, blocks)) - } - - fn create(&self, path: &str, ty: VfsNodeType) -> VfsResult { - debug!("create {:?} on Ext4fs: {}", ty, path); - let fpath = self.path_deal_with(path); - let fpath = fpath.as_str(); - if fpath.is_empty() { - return Ok(()); - } - let types = match ty { - VfsNodeType::Fifo => InodeTypes::EXT4_DE_FIFO, - VfsNodeType::CharDevice => InodeTypes::EXT4_DE_CHRDEV, - VfsNodeType::Dir => InodeTypes::EXT4_DE_DIR, - VfsNodeType::BlockDevice => InodeTypes::EXT4_DE_BLKDEV, - VfsNodeType::File => InodeTypes::EXT4_DE_REG_FILE, - VfsNodeType::SymLink => InodeTypes::EXT4_DE_SYMLINK, - VfsNodeType::Socket => InodeTypes::EXT4_DE_SOCK, - }; - - let mut file = self.0.lock(); - if file.check_inode_exist(fpath, types.clone()) { - Ok(()) - } else { - if types == InodeTypes::EXT4_DE_DIR { - file.dir_mk(fpath) - .map(|_v| ()) - .map_err(|e| e.try_into().unwrap()) - } else { - file.file_open(fpath, O_WRONLY | O_CREAT | O_TRUNC) - .expect("create file failed"); - file.file_close() - .map(|_v| ()) - .map_err(|e| e.try_into().unwrap()) - } - } - } - - fn remove(&self, path: &str) -> VfsResult { - debug!("remove ext4fs: {}", path); - let fpath = self.path_deal_with(path); - let fpath = fpath.as_str(); - assert!(!fpath.is_empty()); // already check at `root.rs` - let mut file = self.0.lock(); - if file.check_inode_exist(fpath, InodeTypes::EXT4_DE_DIR) { - // Recursive directory remove - file.dir_rm(fpath) - .map(|_v| ()) - .map_err(|e| e.try_into().unwrap()) - } else { - file.file_remove(fpath) - .map(|_v| ()) - .map_err(|e| e.try_into().unwrap()) - } - } - - /// Get the parent directory of this directory. - /// Return `None` if the node is a file. - fn parent(&self) -> Option { - let file = self.0.lock(); - if file.get_type() == InodeTypes::EXT4_DE_DIR { - let path = file.get_path().to_str().unwrap().to_string(); - debug!("Get the parent dir of {}", path); - let path = path.trim_end_matches('/').trim_end_matches(|c| c != '/'); - if !path.is_empty() { - return Some(Arc::new(Self::new(path, InodeTypes::EXT4_DE_DIR))); - } - } - None - } - - /// Read directory entries into `dirents`, starting from `start_idx`. - fn read_dir(&self, start_idx: usize, dirents: &mut [VfsDirEntry]) -> VfsResult { - let file = self.0.lock(); - let (names, inode_types) = file.lwext4_dir_entries().unwrap(); - for (i, out_entry) in dirents.iter_mut().enumerate() { - let iname = names.get(start_idx + i); - let itype = inode_types.get(start_idx + i); - match (iname, itype) { - (Some(name), Some(t)) => { - let ty = match t { - InodeTypes::EXT4_DE_DIR => VfsNodeType::Dir, - InodeTypes::EXT4_DE_REG_FILE => VfsNodeType::File, - InodeTypes::EXT4_DE_SYMLINK => VfsNodeType::SymLink, - _ => { - error!("unknown file type: {:?}", t); - unreachable!() - } - }; - *out_entry = VfsDirEntry::new(core::str::from_utf8(name).unwrap(), ty); - } - _ => return Ok(i), - } - } - Ok(dirents.len()) - } - - /// Lookup the node with given `path` in the directory. - /// Return the node if found. - fn lookup(self: Arc, path: &str) -> VfsResult { - trace!("lookup ext4fs: {:?}, {}", self.0.lock().get_path(), path); - let fpath = self.path_deal_with(path); - let fpath = fpath.as_str(); - if fpath.is_empty() { - return Ok(self.clone()); - } - let mut file = self.0.lock(); - if file.check_inode_exist(fpath, InodeTypes::EXT4_DE_DIR) { - trace!("lookup new DIR FileWrapper"); - Ok(Arc::new(Self::new(fpath, InodeTypes::EXT4_DE_DIR))) - } else if file.check_inode_exist(fpath, InodeTypes::EXT4_DE_REG_FILE) { - trace!("lookup new FILE FileWrapper"); - Ok(Arc::new(Self::new(fpath, InodeTypes::EXT4_DE_REG_FILE))) - } else { - Err(VfsError::NotFound) - } - } - - fn read_at(&self, offset: u64, buf: &mut [u8]) -> VfsResult { - trace!("To read_at {}, buf len={}", offset, buf.len()); - let mut file = self.0.lock(); - let path = file.get_path().to_str().unwrap().to_string(); - file.file_open(&path, O_RDONLY) - .map_err(|e| >::try_into(e).unwrap())?; - - file.file_seek(offset as i64, SEEK_SET) - .map_err(|e| >::try_into(e).unwrap())?; - let result = file.file_read(buf); - file.file_close().expect("failed to close fd"); - result.map_err(|e| e.try_into().unwrap()) - } - - fn write_at(&self, offset: u64, buf: &[u8]) -> VfsResult { - trace!("To write_at {}, buf len={}", offset, buf.len()); - let mut file = self.0.lock(); - let path = file.get_path().to_str().unwrap().to_string(); - file.file_open(&path, O_RDWR) - .map_err(|e| >::try_into(e).unwrap())?; - - file.file_seek(offset as i64, SEEK_SET) - .map_err(|e| >::try_into(e).unwrap())?; - let result = file.file_write(buf); - file.file_close().expect("failed to close fd"); - result.map_err(|e| e.try_into().unwrap()) - } - - fn truncate(&self, size: u64) -> VfsResult { - debug!("truncate file to size={}", size); - let mut file = self.0.lock(); - let path = file.get_path().to_str().unwrap().to_string(); - file.file_open(&path, O_RDWR | O_CREAT | O_TRUNC) - .map_err(|e| >::try_into(e).unwrap())?; - - let result = file.file_truncate(size); - file.file_close().expect("failed to close fd"); - result.map(|_| ()).map_err(|e| e.try_into().unwrap()) - } - - fn rename(&self, src_path: &str, dst_path: &str) -> VfsResult { - debug!("rename from {} to {}", src_path, dst_path); - let mut file = self.0.lock(); - file.file_rename(src_path, dst_path) - .map(|_| ()) - .map_err(|e| e.try_into().unwrap()) - } - - fn as_any(&self) -> &dyn core::any::Any { - self as &dyn core::any::Any - } -} - -impl Drop for FileWrapper { - fn drop(&mut self) { - let mut file = self.0.lock(); - debug!("Drop struct FileWrapper {:?}", file.get_path()); - file.file_close().expect("failed to close fd"); - drop(file); // todo - } -} - -impl KernelDevOp for Disk { - type DevType = Disk; - - fn read(dev: &mut Disk, mut buf: &mut [u8]) -> Result { - trace!("READ block device buf={}", buf.len()); - let mut read_len = 0; - while !buf.is_empty() { - match dev.read_one(buf) { - Ok(0) => break, - Ok(n) => { - buf = &mut buf[n..]; - read_len += n; - } - Err(_) => return Err(DevError::Io as i32), - } - } - trace!("READ rt len={}", read_len); - Ok(read_len) - } - - fn write(dev: &mut Self::DevType, mut buf: &[u8]) -> Result { - trace!("WRITE block device buf={}", buf.len()); - let mut write_len = 0; - while !buf.is_empty() { - match dev.write_one(buf) { - Ok(0) => break, - Ok(n) => { - buf = &buf[n..]; - write_len += n; - } - Err(_e) => return Err(DevError::Io as i32), - } - } - trace!("WRITE rt len={}", write_len); - Ok(write_len) - } - - fn flush(_dev: &mut Self::DevType) -> Result { - debug!("uncomplicated"); - Ok(0) - } - - fn seek(dev: &mut Disk, off: i64, whence: i32) -> Result { - let size = dev.size(); - trace!( - "SEEK block device size:{}, pos:{}, offset={}, whence={}", - size, - &dev.position(), - off, - whence - ); - let new_pos = match whence as u32 { - SEEK_SET => Some(off), - SEEK_CUR => dev.position().checked_add_signed(off).map(|v| v as i64), - SEEK_END => size.checked_add_signed(off).map(|v| v as i64), - _ => { - error!("invalid seek() whence: {}", whence); - Some(off) - } - } - .ok_or(DevError::Io as i32)?; - if new_pos as u64 > size { - warn!("Seek beyond the end of the block device"); - } - dev.set_position(new_pos as u64); - Ok(new_pos) - } -} diff --git a/modules/axfs/src/fs/fat/dir.rs b/modules/axfs/src/fs/fat/dir.rs new file mode 100644 index 0000000000..83664441bf --- /dev/null +++ b/modules/axfs/src/fs/fat/dir.rs @@ -0,0 +1,220 @@ +use alloc::{string::String, sync::Arc}; +use core::{any::Any, mem, ops::Deref, time::Duration}; + +use axfs_ng_vfs::{ + DeviceId, DirEntry, DirEntrySink, DirNode, DirNodeOps, FilesystemOps, Metadata, MetadataUpdate, + NodeFlags, NodeOps, NodePermission, NodeType, Reference, VfsError, VfsResult, WeakDirEntry, +}; + +use super::{ + FsRef, ff, + file::FatFileNode, + fs::FatFilesystem, + util::{file_metadata, into_vfs_err}, +}; + +pub struct FatDirNode { + fs: Arc, + pub(crate) inner: FsRef>, + inode: u64, + this: WeakDirEntry, +} + +impl FatDirNode { + pub fn new(fs: Arc, dir: ff::Dir, inode: u64, this: WeakDirEntry) -> DirNode { + DirNode::new(Arc::new(Self { + fs, + // SAFETY: FsRef guarantees correct lifetime + inner: FsRef::new(unsafe { mem::transmute::(dir) }), + inode, + this, + })) + } + + fn create_entry(&self, entry: ff::DirEntry, name: impl Into, inode: u64) -> DirEntry { + let reference = Reference::new(self.this.upgrade(), name.into()); + if entry.is_file() { + DirEntry::new_file( + FatFileNode::new(self.fs.clone(), entry.to_file(), inode), + NodeType::RegularFile, + reference, + ) + } else { + DirEntry::new_dir( + |this| FatDirNode::new(self.fs.clone(), entry.to_dir(), inode, this), + reference, + ) + } + } +} + +unsafe impl Send for FatDirNode {} + +unsafe impl Sync for FatDirNode {} + +impl NodeOps for FatDirNode { + fn inode(&self) -> u64 { + self.inode + } + + fn metadata(&self) -> VfsResult { + let fs = self.fs.lock(); + let dir = self.inner.borrow(&fs); + if let Some(file) = dir.as_file() { + return Ok(file_metadata(&fs, file, NodeType::Directory)); + } + + // root directory + let block_size = fs.inner.bytes_per_sector() as u64; + Ok(Metadata { + inode: self.inode(), + device: 0, + nlink: 1, + mode: NodePermission::default(), + node_type: NodeType::Directory, + uid: 0, + gid: 0, + size: block_size, + block_size, + blocks: 1, + rdev: DeviceId::default(), + atime: Duration::default(), + mtime: Duration::default(), + ctime: Duration::default(), + }) + } + + fn update_metadata(&self, _update: MetadataUpdate) -> VfsResult<()> { + // TODO: update metadata on directory + Ok(()) + } + + fn filesystem(&self) -> &dyn FilesystemOps { + self.fs.deref() + } + + fn sync(&self, _data_only: bool) -> VfsResult<()> { + Ok(()) + } + + fn into_any(self: Arc) -> Arc { + self + } + + fn flags(&self) -> NodeFlags { + NodeFlags::BLOCKING + } +} + +impl DirNodeOps for FatDirNode { + fn read_dir(&self, offset: u64, sink: &mut dyn DirEntrySink) -> VfsResult { + let mut fs = self.fs.lock(); + let dir = self.inner.borrow(&fs); + let this_entry = self.this.upgrade().unwrap(); + let dir_node = this_entry.as_dir()?; + + let mut count = 0; + for entry in dir.iter().skip(offset as usize) { + let entry = entry.map_err(into_vfs_err)?; + let name = entry.file_name().to_ascii_lowercase(); + let node_type = if entry.is_file() { + NodeType::RegularFile + } else { + NodeType::Directory + }; + let inode = if let Some(entry) = dir_node.lookup_cache(&name) { + entry.inode() + } else { + let entry = self.create_entry(entry, name.clone(), fs.alloc_inode()); + let inode = entry.inode(); + dir_node.insert_cache(name.clone(), entry); + inode + }; + if !sink.accept(&name, inode, node_type, offset + count + 1) { + break; + } + count += 1; + } + Ok(count as usize) + } + + fn lookup(&self, name: &str) -> VfsResult { + let mut fs = self.fs.lock(); + let dir = self.inner.borrow(&fs); + dir.iter() + .find_map(|entry| entry.ok().filter(|it| it.eq_name(name))) + .map(|entry| self.create_entry(entry, name.to_ascii_lowercase(), fs.alloc_inode())) + .ok_or(VfsError::NotFound) + } + + fn create( + &self, + name: &str, + node_type: NodeType, + _permission: NodePermission, + ) -> VfsResult { + let mut fs = self.fs.lock(); + let dir = self.inner.borrow(&fs); + let reference = Reference::new(self.this.upgrade(), name.to_ascii_lowercase()); + match node_type { + NodeType::RegularFile => dir + .create_file(name) + .map(|file| { + DirEntry::new_file( + FatFileNode::new(self.fs.clone(), file, fs.alloc_inode()), + NodeType::RegularFile, + reference, + ) + }) + .map_err(into_vfs_err), + NodeType::Directory => dir + .create_dir(name) + .map(|dir| { + DirEntry::new_dir( + |this| FatDirNode::new(self.fs.clone(), dir, fs.alloc_inode(), this), + reference, + ) + }) + .map_err(into_vfs_err), + _ => Err(VfsError::InvalidInput), + } + } + + fn link(&self, _name: &str, _node: &DirEntry) -> VfsResult { + // EPERM The filesystem containing oldpath and newpath does not + // support the creation of hard links. + Err(VfsError::PermissionDenied) + } + + fn unlink(&self, name: &str) -> VfsResult<()> { + let fs = self.fs.lock(); + let dir = self.inner.borrow(&fs); + dir.remove(name).map_err(into_vfs_err) + } + + fn rename(&self, src_name: &str, dst_dir: &DirNode, dst_name: &str) -> VfsResult<()> { + let fs = self.fs.lock(); + let dst_dir: Arc = dst_dir.downcast().map_err(|_| VfsError::InvalidInput)?; + + let dir = self.inner.borrow(&fs); + + // The default implementation throws EEXIST if dst exists, so we need to + // handle it + match dst_dir.inner.borrow(&fs).remove(dst_name) { + Ok(_) => { + warn!("对 I removed {}", dst_name); + } + Err(fatfs::Error::NotFound) => {} + Err(err) => return Err(into_vfs_err(err)), + } + + dir.rename(src_name, dst_dir.inner.borrow(&fs), dst_name) + .map_err(into_vfs_err) + } +} + +impl Drop for FatDirNode { + fn drop(&mut self) { + self.fs.lock().release_inode(self.inode); + } +} diff --git a/modules/axfs/src/fs/fat/ff.rs b/modules/axfs/src/fs/fat/ff.rs new file mode 100644 index 0000000000..77b07b3689 --- /dev/null +++ b/modules/axfs/src/fs/fat/ff.rs @@ -0,0 +1,13 @@ +//! Type aliases for `fatfs`. + +use fatfs::{DefaultTimeProvider, LossyOemCpConverter}; + +use crate::disk::SeekableDisk; + +pub type FileSystem = fatfs::FileSystem; + +pub type Dir<'a> = fatfs::Dir<'a, SeekableDisk, DefaultTimeProvider, LossyOemCpConverter>; + +pub type DirEntry<'a> = fatfs::DirEntry<'a, SeekableDisk, DefaultTimeProvider, LossyOemCpConverter>; + +pub type File<'a> = fatfs::File<'a, SeekableDisk, DefaultTimeProvider, LossyOemCpConverter>; diff --git a/modules/axfs/src/fs/fat/file.rs b/modules/axfs/src/fs/fat/file.rs new file mode 100644 index 0000000000..5e6381936a --- /dev/null +++ b/modules/axfs/src/fs/fat/file.rs @@ -0,0 +1,171 @@ +use alloc::{sync::Arc, vec}; +use core::{any::Any, mem, ops::Deref, task::Context}; + +use axfs_ng_vfs::{ + FileNode, FileNodeOps, FilesystemOps, Metadata, MetadataUpdate, NodeFlags, NodeOps, NodeType, + VfsError, VfsResult, +}; +use axpoll::{IoEvents, Pollable}; +use fatfs::{Read, Seek, SeekFrom, Write}; + +use super::{ + FsRef, ff, + fs::FatFilesystem, + util::{file_metadata, into_vfs_err, update_file_metadata}, +}; +use crate::fs::fat::fs::FatFilesystemInner; + +pub struct FatFileNode { + fs: Arc, + inner: FsRef>, + inode: u64, +} + +impl FatFileNode { + pub fn new(fs: Arc, file: ff::File, inode: u64) -> FileNode { + FileNode::new(Arc::new(Self { + fs, + // SAFETY: FsRef guarantees correct lifetime + inner: FsRef::new(unsafe { mem::transmute::(file) }), + inode, + })) + } +} + +fn grow_file(fs: &FatFilesystemInner, file: &mut ff::File<'static>, len: u64) -> VfsResult<()> { + // rust-fatfs does not support growing files directly. We need to + // pad with zeros manually. + let mut pos = file.seek(SeekFrom::End(0)).map_err(into_vfs_err)?; + let block_size = fs.inner.bytes_per_sector() as usize; + let block = vec![0; block_size]; + + while pos < len { + let write = (block_size - (pos as usize & (block_size - 1))).min((len - pos) as usize); + file.write(&block[0..write]).map_err(into_vfs_err)?; + pos += write as u64; + } + Ok(()) +} + +unsafe impl Send for FatFileNode {} + +unsafe impl Sync for FatFileNode {} + +impl NodeOps for FatFileNode { + fn inode(&self) -> u64 { + self.inode + } + + fn metadata(&self) -> VfsResult { + let fs = self.fs.lock(); + let file = self.inner.borrow(&fs); + Ok(file_metadata(&fs, file, NodeType::RegularFile)) + } + + fn update_metadata(&self, update: MetadataUpdate) -> VfsResult<()> { + // FatFS has no ownership & permission + + let fs = self.fs.lock(); + let file = self.inner.borrow_mut(&fs); + update_file_metadata(file, update); + Ok(()) + } + + fn filesystem(&self) -> &dyn FilesystemOps { + self.fs.deref() + } + + fn len(&self) -> VfsResult { + let fs = self.fs.lock(); + let file = self.inner.borrow(&fs); + Ok(file.size().unwrap_or(0) as u64) + } + + fn sync(&self, _data_only: bool) -> VfsResult<()> { + let fs = self.fs.lock(); + let file = self.inner.borrow_mut(&fs); + file.flush().map_err(into_vfs_err) + } + + fn into_any(self: Arc) -> Arc { + self + } + + fn flags(&self) -> NodeFlags { + NodeFlags::BLOCKING + } +} + +impl FileNodeOps for FatFileNode { + fn read_at(&self, mut buf: &mut [u8], offset: u64) -> VfsResult { + let fs = self.fs.lock(); + let file = self.inner.borrow_mut(&fs); + file.seek(SeekFrom::Start(offset)).map_err(into_vfs_err)?; + + let mut read = 0; + loop { + let n = file.read(buf).map_err(into_vfs_err)?; + if n == 0 { + return Ok(read); + } + read += n; + buf = &mut buf[n..]; + } + } + + fn write_at(&self, mut buf: &[u8], offset: u64) -> VfsResult { + let fs = self.fs.lock(); + let file = self.inner.borrow_mut(&fs); + if offset > file.size().unwrap_or(0) as u64 { + grow_file(&fs, file, offset)?; + } + file.seek(SeekFrom::Start(offset)).map_err(into_vfs_err)?; + + let mut written = 0; + loop { + let n = file.write(buf).map_err(into_vfs_err)?; + if n == 0 { + return Ok(written); + } + written += n; + buf = &buf[n..]; + } + } + + fn append(&self, buf: &[u8]) -> VfsResult<(usize, u64)> { + let fs = self.fs.lock(); + let file = self.inner.borrow_mut(&fs); + file.seek(SeekFrom::End(0)).map_err(into_vfs_err)?; + let written = file.write(buf).map_err(into_vfs_err)?; + Ok((written, file.size().unwrap_or(0) as u64)) + } + + fn set_len(&self, len: u64) -> VfsResult<()> { + let fs = self.fs.lock(); + let file = self.inner.borrow_mut(&fs); + if len <= file.size().unwrap_or(0) as u64 { + file.seek(SeekFrom::Start(len)).map_err(into_vfs_err)?; + file.truncate().map_err(into_vfs_err) + } else { + grow_file(&fs, file, len) + } + } + + fn set_symlink(&self, _target: &str) -> VfsResult<()> { + Err(VfsError::PermissionDenied) + } +} + +impl Pollable for FatFileNode { + fn poll(&self) -> IoEvents { + IoEvents::IN | IoEvents::OUT + } + + fn register(&self, _context: &mut Context<'_>, _events: IoEvents) {} +} + +impl Drop for FatFileNode { + fn drop(&mut self) { + self.fs.lock().release_inode(self.inode); + } +} diff --git a/modules/axfs/src/fs/fat/fs.rs b/modules/axfs/src/fs/fat/fs.rs new file mode 100644 index 0000000000..167f58fe6c --- /dev/null +++ b/modules/axfs/src/fs/fat/fs.rs @@ -0,0 +1,98 @@ +use alloc::sync::Arc; +use core::marker::PhantomPinned; + +use axdriver::AxBlockDevice; +use axfs_ng_vfs::{ + DirEntry, Filesystem, FilesystemOps, Reference, StatFs, VfsResult, path::MAX_NAME_LEN, +}; +use kspin::{SpinNoPreempt as Mutex, SpinNoPreemptGuard as MutexGuard}; +use slab::Slab; + +use super::{dir::FatDirNode, ff, util::into_vfs_err}; +use crate::disk::SeekableDisk; + +pub struct FatFilesystemInner { + pub inner: ff::FileSystem, + inode_allocator: Slab<()>, + _pinned: PhantomPinned, +} + +impl FatFilesystemInner { + pub(crate) fn alloc_inode(&mut self) -> u64 { + self.inode_allocator.insert(()) as u64 + 1 + } + + pub(crate) fn release_inode(&mut self, ino: u64) { + self.inode_allocator.remove(ino as usize - 1); + } +} + +pub struct FatFilesystem { + inner: Mutex, + root_dir: Mutex>, +} + +impl FatFilesystem { + pub fn new(dev: AxBlockDevice) -> Filesystem { + let mut inner = FatFilesystemInner { + inner: ff::FileSystem::new(SeekableDisk::new(dev), fatfs::FsOptions::new()) + .expect("failed to initialize FAT filesystem"), + inode_allocator: Slab::new(), + _pinned: PhantomPinned, + }; + let root_inode = inner.alloc_inode(); + let result = Arc::new(Self { + inner: Mutex::new(inner), + root_dir: Mutex::default(), + }); + + let root_dir = DirEntry::new_dir( + |this| { + FatDirNode::new( + result.clone(), + result.lock().inner.root_dir(), + root_inode, + this, + ) + }, + Reference::root(), + ); + *result.root_dir.lock() = Some(root_dir); + Filesystem::new(result) + } +} + +impl FatFilesystem { + pub(crate) fn lock(&self) -> MutexGuard { + self.inner.lock() + } +} + +impl FilesystemOps for FatFilesystem { + fn name(&self) -> &str { + "vfat" + } + + fn root_dir(&self) -> DirEntry { + self.root_dir.lock().clone().unwrap() + } + + fn stat(&self) -> VfsResult { + let fs = self.inner.lock(); + let stats = fs.inner.stats().map_err(into_vfs_err)?; + Ok(StatFs { + fs_type: 0x65735546, // fuse + block_size: stats.cluster_size() as _, + blocks: stats.total_clusters() as _, + blocks_free: stats.free_clusters() as _, + blocks_available: stats.free_clusters() as _, + + file_count: 0, + free_file_count: 0, + + name_length: MAX_NAME_LEN as _, + fragment_size: 0, + mount_flags: 0, + }) + } +} diff --git a/modules/axfs/src/fs/fat/mod.rs b/modules/axfs/src/fs/fat/mod.rs new file mode 100644 index 0000000000..dedbc59672 --- /dev/null +++ b/modules/axfs/src/fs/fat/mod.rs @@ -0,0 +1,71 @@ +mod dir; +mod ff; +mod file; +mod fs; +mod util; + +use core::cell::UnsafeCell; + +use fatfs::SeekFrom; +pub use fs::FatFilesystem; +use fs::FatFilesystemInner; + +use crate::disk::SeekableDisk; + +impl fatfs::IoBase for SeekableDisk { + type Error = (); +} + +impl fatfs::Read for SeekableDisk { + fn read(&mut self, buf: &mut [u8]) -> Result { + SeekableDisk::read(self, buf).map_err(|_| ()) + } +} + +impl fatfs::Write for SeekableDisk { + fn write(&mut self, buf: &[u8]) -> Result { + SeekableDisk::write(self, buf).map_err(|_| ()) + } + + fn flush(&mut self) -> Result<(), Self::Error> { + SeekableDisk::flush(self).map_err(|_| ()) + } +} + +impl fatfs::Seek for SeekableDisk { + fn seek(&mut self, pos: SeekFrom) -> Result { + let size = self.size(); + let new_pos = match pos { + SeekFrom::Start(pos) => Some(pos), + SeekFrom::Current(off) => self.position().checked_add_signed(off), + SeekFrom::End(off) => size.checked_add_signed(off), + } + .ok_or(())?; + self.set_position(new_pos).map_err(|_| ())?; + Ok(new_pos) + } +} + +/// A reference to an object within a filesystem. +pub(crate) struct FsRef { + inner: UnsafeCell, +} + +impl FsRef { + pub fn new(inner: T) -> Self { + Self { + inner: UnsafeCell::new(inner), + } + } + + pub fn borrow<'a>(&self, _fs: &'a FatFilesystemInner) -> &'a T { + // SAFETY: The filesystem outlives the reference + unsafe { &*self.inner.get() } + } + + #[allow(clippy::mut_from_ref)] + pub fn borrow_mut<'a>(&self, _fs: &'a FatFilesystemInner) -> &'a mut T { + // SAFETY: The filesystem outlives the reference + unsafe { &mut *self.inner.get() } + } +} diff --git a/modules/axfs/src/fs/fat/util.rs b/modules/axfs/src/fs/fat/util.rs new file mode 100644 index 0000000000..ec6ae6b744 --- /dev/null +++ b/modules/axfs/src/fs/fat/util.rs @@ -0,0 +1,130 @@ +use alloc::string::String; +use core::time::Duration; + +use axfs_ng_vfs::{DeviceId, Metadata, MetadataUpdate, NodePermission, NodeType, VfsError}; +use chrono::{DateTime, Datelike, NaiveDate, TimeZone, Timelike, Utc}; + +use super::{ff, fs::FatFilesystemInner}; + +#[derive(Clone)] +pub struct CaseInsensitiveString(pub String); + +impl PartialEq for CaseInsensitiveString { + fn eq(&self, other: &Self) -> bool { + self.0 + .bytes() + .map(|c| c.to_ascii_lowercase()) + .eq(other.0.bytes().map(|c| c.to_ascii_lowercase())) + } +} + +impl Eq for CaseInsensitiveString {} + +impl PartialOrd for CaseInsensitiveString { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for CaseInsensitiveString { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.0 + .bytes() + .map(|c| c.to_ascii_lowercase()) + .cmp(other.0.bytes().map(|c| c.to_ascii_lowercase())) + } +} + +pub fn dos_to_unix(date: fatfs::DateTime) -> Duration { + // let date: NaiveDateTime = date.into(); + let date = NaiveDate::from_ymd_opt( + date.date.year as _, + date.date.month as _, + date.date.day as _, + ) + .unwrap() + .and_hms_milli_opt( + date.time.hour as _, + date.time.min as _, + date.time.sec as _, + date.time.millis as _, + ) + .unwrap(); + let Some(datetime) = Utc.from_local_datetime(&date).single() else { + return Duration::default(); + }; + datetime + .signed_duration_since(DateTime::UNIX_EPOCH) + .to_std() + .unwrap_or_default() +} + +pub fn unix_to_dos(datetime: Duration) -> fatfs::DateTime { + let dt = DateTime::UNIX_EPOCH + datetime; + let dt = dt.naive_local(); + + fatfs::DateTime::new( + fatfs::Date::new(dt.year() as _, dt.month() as _, dt.day() as _), + fatfs::Time::new( + dt.hour() as _, + dt.minute() as _, + dt.second() as _, + dt.and_utc().timestamp_subsec_millis() as _, + ), + ) +} + +pub fn file_metadata(fs: &FatFilesystemInner, file: &ff::File, node_type: NodeType) -> Metadata { + let size = file.size().unwrap_or(0) as u64; + let block_size = fs.inner.bytes_per_sector(); + Metadata { + // TODO: inode + inode: 1, + device: 0, + nlink: 1, + mode: NodePermission::default(), + node_type, + uid: 0, + gid: 0, + size, + block_size: block_size as _, + // TODO: The correct block count should be obtained from + // `file.extents()`. However it would be costly. This implementation + // would be enough for now. + blocks: size / block_size as u64, + rdev: DeviceId::default(), + atime: dos_to_unix(fatfs::DateTime::new( + file.accessed(), + fatfs::Time::new(0, 0, 0, 0), + )), + mtime: dos_to_unix(file.modified()), + ctime: dos_to_unix(file.created()), + } +} + +pub fn update_file_metadata(file: &mut ff::File, update: MetadataUpdate) { + if let Some(atime) = update.atime { + #[allow(deprecated)] + file.set_accessed(unix_to_dos(atime).date); + } + if let Some(mtime) = update.mtime { + #[allow(deprecated)] + file.set_modified(unix_to_dos(mtime)); + } +} + +pub fn into_vfs_err(err: fatfs::Error) -> VfsError { + use fatfs::Error::*; + match err { + AlreadyExists => VfsError::AlreadyExists, + CorruptedFileSystem => VfsError::InvalidData, + DirectoryIsNotEmpty => VfsError::DirectoryNotEmpty, + InvalidFileNameLength => VfsError::NameTooLong, + InvalidInput => VfsError::InvalidInput, + UnsupportedFileNameCharacter => VfsError::InvalidData, + NotEnoughSpace => VfsError::StorageFull, + NotFound => VfsError::NotFound, + UnexpectedEof | WriteZero => VfsError::Io, + _ => VfsError::Io, + } +} diff --git a/modules/axfs/src/fs/fatfs.rs b/modules/axfs/src/fs/fatfs.rs deleted file mode 100644 index 549568ec9a..0000000000 --- a/modules/axfs/src/fs/fatfs.rs +++ /dev/null @@ -1,298 +0,0 @@ -use alloc::sync::Arc; -use core::cell::UnsafeCell; - -use axfs_vfs::{VfsDirEntry, VfsError, VfsNodePerm, VfsResult}; -use axfs_vfs::{VfsNodeAttr, VfsNodeOps, VfsNodeRef, VfsNodeType, VfsOps}; -use axsync::Mutex; -use fatfs::{Dir, File, LossyOemCpConverter, NullTimeProvider, Read, Seek, SeekFrom, Write}; - -use crate::dev::Disk; - -const BLOCK_SIZE: usize = 512; - -pub struct FatFileSystem { - inner: fatfs::FileSystem, - root_dir: UnsafeCell>, -} - -pub struct FileWrapper<'a>(Mutex>); -pub struct DirWrapper<'a>(Dir<'a, Disk, NullTimeProvider, LossyOemCpConverter>); - -unsafe impl Sync for FatFileSystem {} -unsafe impl Send for FatFileSystem {} -unsafe impl Send for FileWrapper<'_> {} -unsafe impl Sync for FileWrapper<'_> {} -unsafe impl Send for DirWrapper<'_> {} -unsafe impl Sync for DirWrapper<'_> {} - -impl FatFileSystem { - #[cfg(feature = "use-ramdisk")] - pub fn new(mut disk: Disk) -> Self { - let opts = fatfs::FormatVolumeOptions::new(); - fatfs::format_volume(&mut disk, opts).expect("failed to format volume"); - let inner = fatfs::FileSystem::new(disk, fatfs::FsOptions::new()) - .expect("failed to initialize FAT filesystem"); - Self { - inner, - root_dir: UnsafeCell::new(None), - } - } - - #[cfg(not(feature = "use-ramdisk"))] - pub fn new(disk: Disk) -> Self { - let inner = fatfs::FileSystem::new(disk, fatfs::FsOptions::new()) - .expect("failed to initialize FAT filesystem"); - Self { - inner, - root_dir: UnsafeCell::new(None), - } - } - - pub fn init(&'static self) { - // must be called before later operations - unsafe { *self.root_dir.get() = Some(Self::new_dir(self.inner.root_dir())) } - } - - fn new_file(file: File<'_, Disk, NullTimeProvider, LossyOemCpConverter>) -> Arc { - Arc::new(FileWrapper(Mutex::new(file))) - } - - fn new_dir(dir: Dir<'_, Disk, NullTimeProvider, LossyOemCpConverter>) -> Arc { - Arc::new(DirWrapper(dir)) - } -} - -impl VfsNodeOps for FileWrapper<'static> { - axfs_vfs::impl_vfs_non_dir_default! {} - - fn get_attr(&self) -> VfsResult { - let size = self.0.lock().seek(SeekFrom::End(0)).map_err(as_vfs_err)?; - let blocks = size.div_ceil(BLOCK_SIZE as u64); - // FAT fs doesn't support permissions, we just set everything to 755 - let perm = VfsNodePerm::from_bits_truncate(0o755); - Ok(VfsNodeAttr::new(perm, VfsNodeType::File, size, blocks)) - } - - fn read_at(&self, offset: u64, buf: &mut [u8]) -> VfsResult { - let mut file = self.0.lock(); - file.seek(SeekFrom::Start(offset)).map_err(as_vfs_err)?; // TODO: more efficient - file.read(buf).map_err(as_vfs_err) - } - - fn write_at(&self, offset: u64, buf: &[u8]) -> VfsResult { - let mut file = self.0.lock(); - file.seek(SeekFrom::Start(offset)).map_err(as_vfs_err)?; // TODO: more efficient - file.write(buf).map_err(as_vfs_err) - } - - fn truncate(&self, size: u64) -> VfsResult { - let mut file = self.0.lock(); - let current_size = file.seek(SeekFrom::End(0)).map_err(as_vfs_err)?; - - if size <= current_size { - // If the target size is smaller than the current size, - // perform a standard truncation operation - file.seek(SeekFrom::Start(size)).map_err(as_vfs_err)?; // TODO: more efficient - file.truncate().map_err(as_vfs_err) - } else { - // Calculate the number of bytes to fill - let mut zeros_needed = size - current_size; - // Create a buffer of zeros - let zeros = [0u8; 4096]; - while zeros_needed > 0 { - let to_write = core::cmp::min(zeros_needed, zeros.len() as u64); - let write_buf = &zeros[..to_write as usize]; - file.write(write_buf).map_err(as_vfs_err)?; - zeros_needed -= to_write; - } - Ok(()) - } - } -} - -impl VfsNodeOps for DirWrapper<'static> { - axfs_vfs::impl_vfs_dir_default! {} - - fn get_attr(&self) -> VfsResult { - // FAT fs doesn't support permissions, we just set everything to 755 - Ok(VfsNodeAttr::new( - VfsNodePerm::from_bits_truncate(0o755), - VfsNodeType::Dir, - BLOCK_SIZE as u64, - 1, - )) - } - - fn parent(&self) -> Option { - self.0 - .open_dir("..") - .map_or(None, |dir| Some(FatFileSystem::new_dir(dir))) - } - - fn lookup(self: Arc, path: &str) -> VfsResult { - debug!("lookup at fatfs: {path}"); - let path = path.trim_matches('/'); - if path.is_empty() || path == "." { - return Ok(self.clone()); - } - if let Some(rest) = path.strip_prefix("./") { - return self.lookup(rest); - } - - // TODO: use `fatfs::Dir::find_entry`, but it's not public. - if let Ok(file) = self.0.open_file(path) { - Ok(FatFileSystem::new_file(file)) - } else if let Ok(dir) = self.0.open_dir(path) { - Ok(FatFileSystem::new_dir(dir)) - } else { - Err(VfsError::NotFound) - } - } - - fn create(&self, path: &str, ty: VfsNodeType) -> VfsResult { - debug!("create {ty:?} at fatfs: {path}"); - let path = path.trim_matches('/'); - if path.is_empty() || path == "." { - return Ok(()); - } - if let Some(rest) = path.strip_prefix("./") { - return self.create(rest, ty); - } - - match ty { - VfsNodeType::File => { - self.0.create_file(path).map_err(as_vfs_err)?; - Ok(()) - } - VfsNodeType::Dir => { - self.0.create_dir(path).map_err(as_vfs_err)?; - Ok(()) - } - _ => Err(VfsError::Unsupported), - } - } - - fn remove(&self, path: &str) -> VfsResult { - debug!("remove at fatfs: {path}"); - let path = path.trim_matches('/'); - assert!(!path.is_empty()); // already check at `root.rs` - if let Some(rest) = path.strip_prefix("./") { - return self.remove(rest); - } - self.0.remove(path).map_err(as_vfs_err) - } - - fn read_dir(&self, start_idx: usize, dirents: &mut [VfsDirEntry]) -> VfsResult { - let mut iter = self.0.iter().skip(start_idx); - for (i, out_entry) in dirents.iter_mut().enumerate() { - let x = iter.next(); - match x { - Some(Ok(entry)) => { - let ty = if entry.is_dir() { - VfsNodeType::Dir - } else if entry.is_file() { - VfsNodeType::File - } else { - unreachable!() - }; - *out_entry = VfsDirEntry::new(&entry.file_name(), ty); - } - _ => return Ok(i), - } - } - Ok(dirents.len()) - } - - fn rename(&self, src_path: &str, dst_path: &str) -> VfsResult { - // `src_path` and `dst_path` should in the same mounted fs - debug!("rename at fatfs, src_path: {src_path}, dst_path: {dst_path}"); - - self.0 - .rename(src_path, &self.0, dst_path) - .map_err(as_vfs_err) - } -} - -impl VfsOps for FatFileSystem { - fn root_dir(&self) -> VfsNodeRef { - let root_dir = unsafe { (*self.root_dir.get()).as_ref().unwrap() }; - root_dir.clone() - } -} - -impl fatfs::IoBase for Disk { - type Error = (); -} - -impl Read for Disk { - fn read(&mut self, mut buf: &mut [u8]) -> Result { - let mut read_len = 0; - while !buf.is_empty() { - match self.read_one(buf) { - Ok(0) => break, - Ok(n) => { - let tmp = buf; - buf = &mut tmp[n..]; - read_len += n; - } - Err(_) => return Err(()), - } - } - Ok(read_len) - } -} - -impl Write for Disk { - fn write(&mut self, mut buf: &[u8]) -> Result { - let mut write_len = 0; - while !buf.is_empty() { - match self.write_one(buf) { - Ok(0) => break, - Ok(n) => { - buf = &buf[n..]; - write_len += n; - } - Err(_) => return Err(()), - } - } - Ok(write_len) - } - fn flush(&mut self) -> Result<(), Self::Error> { - Ok(()) - } -} - -impl Seek for Disk { - fn seek(&mut self, pos: SeekFrom) -> Result { - let size = self.size(); - let new_pos = match pos { - SeekFrom::Start(pos) => Some(pos), - SeekFrom::Current(off) => self.position().checked_add_signed(off), - SeekFrom::End(off) => size.checked_add_signed(off), - } - .ok_or(())?; - if new_pos > size { - warn!("Seek beyond the end of the block device"); - } - self.set_position(new_pos); - Ok(new_pos) - } -} - -const fn as_vfs_err(err: fatfs::Error<()>) -> VfsError { - use fatfs::Error::*; - match err { - AlreadyExists => VfsError::AlreadyExists, - CorruptedFileSystem => VfsError::InvalidData, - DirectoryIsNotEmpty => VfsError::DirectoryNotEmpty, - InvalidInput | InvalidFileNameLength | UnsupportedFileNameCharacter => { - VfsError::InvalidInput - } - NotEnoughSpace => VfsError::StorageFull, - NotFound => VfsError::NotFound, - UnexpectedEof => VfsError::UnexpectedEof, - WriteZero => VfsError::WriteZero, - Io(_) => VfsError::Io, - _ => VfsError::Io, - } -} diff --git a/modules/axfs/src/fs/mod.rs b/modules/axfs/src/fs/mod.rs index 8636249d56..5eb69733da 100644 --- a/modules/axfs/src/fs/mod.rs +++ b/modules/axfs/src/fs/mod.rs @@ -1,15 +1,21 @@ -cfg_if::cfg_if! { - if #[cfg(feature = "myfs")] { - pub mod myfs; - } else if #[cfg(feature = "ext4fs")] { - pub mod ext4fs; - } else if #[cfg(feature = "fatfs")] { - pub mod fatfs; - } -} +#[cfg(feature = "fat")] +mod fat; + +#[cfg(feature = "ext4")] +mod ext4; -#[cfg(feature = "devfs")] -pub use axfs_devfs as devfs; +use axdriver::AxBlockDevice; +use axfs_ng_vfs::{Filesystem, VfsResult}; +use cfg_if::cfg_if; -#[cfg(feature = "ramfs")] -pub use axfs_ramfs as ramfs; +pub fn new_default(dev: AxBlockDevice) -> VfsResult { + cfg_if! { + if #[cfg(feature = "ext4")] { + ext4::Ext4Filesystem::new(dev) + } else if #[cfg(feature = "fat")] { + Ok(fat::FatFilesystem::new(dev)) + } else { + panic!("No filesystem feature enabled"); + } + } +} diff --git a/modules/axfs/src/fs/myfs.rs b/modules/axfs/src/fs/myfs.rs deleted file mode 100644 index ca24fd5c78..0000000000 --- a/modules/axfs/src/fs/myfs.rs +++ /dev/null @@ -1,16 +0,0 @@ -use crate::dev::Disk; -use alloc::sync::Arc; -use axfs_vfs::VfsOps; - -/// The interface to define custom filesystems in user apps. -#[crate_interface::def_interface] -pub trait MyFileSystemIf { - /// Creates a new instance of the filesystem with initialization. - /// - /// TODO: use generic disk type - fn new_myfs(disk: Disk) -> Arc; -} - -pub(crate) fn new_myfs(disk: Disk) -> Arc { - crate_interface::call_interface!(MyFileSystemIf::new_myfs(disk)) -} diff --git a/modules/axfs/src/highlevel/file.rs b/modules/axfs/src/highlevel/file.rs new file mode 100644 index 0000000000..c84a72e5e0 --- /dev/null +++ b/modules/axfs/src/highlevel/file.rs @@ -0,0 +1,944 @@ +use alloc::{ + boxed::Box, + sync::{Arc, Weak}, + vec::Vec, +}; +#[cfg(feature = "times")] +use core::sync::atomic::{AtomicU8, Ordering}; +use core::{num::NonZeroUsize, ops::Range, task::Context}; + +use axalloc::{UsageKind, global_allocator}; +use axfs_ng_vfs::{ + FileNode, Location, NodeFlags, NodePermission, NodeType, VfsError, VfsResult, path::Path, +}; +use axhal::mem::{PhysAddr, VirtAddr, virt_to_phys}; +use axio::{Buf, BufMut, SeekFrom}; +use axpoll::{IoEvents, Pollable}; +use intrusive_collections::{LinkedList, LinkedListAtomicLink, intrusive_adapter}; +use lru::LruCache; +use spin::{Mutex, RwLock}; + +use super::FsContext; + +bitflags::bitflags! { + #[derive(Debug, Clone, Copy)] + pub struct FileFlags: u8 { + const READ = 1; + const WRITE = 2; + const EXECUTE = 4; + const APPEND = 8; + const PATH = 16; + } +} + +/// Results returned by [`OpenOptions::open`]. +pub enum OpenResult { + File(File), + Dir(Location), +} + +impl OpenResult { + pub fn into_file(self) -> VfsResult { + match self { + Self::File(file) => Ok(file), + Self::Dir(_) => Err(VfsError::IsADirectory), + } + } + + pub fn into_dir(self) -> VfsResult { + match self { + Self::Dir(dir) => Ok(dir), + Self::File(_) => Err(VfsError::NotADirectory), + } + } + + pub fn into_location(self) -> Location { + match self { + Self::File(file) => file.location().clone(), + Self::Dir(dir) => dir, + } + } +} + +/// Options and flags which can be used to configure how a file is opened. +#[derive(Debug, Clone)] +pub struct OpenOptions { + // generic + read: bool, + write: bool, + append: bool, + truncate: bool, + create: bool, + create_new: bool, + directory: bool, + no_follow: bool, + direct: bool, + user: Option<(u32, u32)>, + path: bool, + node_type: NodeType, + // system-specific + mode: u32, +} + +impl OpenOptions { + /// Creates a blank new set of options ready for configuration. + pub fn new() -> Self { + Self { + // generic + read: false, + write: false, + append: false, + truncate: false, + create: false, + create_new: false, + directory: false, + no_follow: false, + direct: false, + user: None, + path: false, + node_type: NodeType::RegularFile, + // system-specific + mode: 0o666, + } + } + + /// Sets the option for read access. + pub fn read(&mut self, read: bool) -> &mut Self { + self.read = read; + self + } + + /// Sets the option for write access. + pub fn write(&mut self, write: bool) -> &mut Self { + self.write = write; + self + } + + /// Sets the option for the append mode. + pub fn append(&mut self, append: bool) -> &mut Self { + self.append = append; + self + } + + /// Sets the option for truncating a previous file. + pub fn truncate(&mut self, truncate: bool) -> &mut Self { + self.truncate = truncate; + self + } + + /// Sets the option to create a new file, or open it if it already exists. + pub fn create(&mut self, create: bool) -> &mut Self { + self.create = create; + self + } + + /// Sets the option to create a new file, failing if it already exists. + pub fn create_new(&mut self, create_new: bool) -> &mut Self { + self.create_new = create_new; + self + } + + /// Sets the option to open directory instead. + pub fn directory(&mut self, directory: bool) -> &mut Self { + self.directory = directory; + self + } + + /// Sets the option to not follow symlinks. + pub fn no_follow(&mut self, no_follow: bool) -> &mut Self { + self.no_follow = no_follow; + self + } + + /// Sets the option to open the file with direct I/O.\ + pub fn direct(&mut self, direct: bool) -> &mut Self { + self.direct = direct; + self + } + + /// Sets the user and group id to open the file with. + pub fn user(&mut self, uid: u32, gid: u32) -> &mut Self { + self.user = Some((uid, gid)); + self + } + + /// Sets the option for path only access. + pub fn path(&mut self, path: bool) -> &mut Self { + self.path = path; + self + } + + /// Sets the node type for the file. + /// + /// This will only be used if the file is created. + pub fn node_type(&mut self, node_type: NodeType) -> &mut Self { + self.node_type = node_type; + self + } + + /// Sets the mode bits that a new file will be created with. + pub fn mode(&mut self, mode: u32) -> &mut Self { + self.mode = mode; + self + } + + fn _open(&self, loc: Location) -> VfsResult { + let flags = self.to_flags()?; + + if self.directory { + if flags.contains(FileFlags::WRITE) { + return Err(VfsError::IsADirectory); + } + loc.check_is_dir()?; + } + if self.truncate { + loc.entry().as_file()?.set_len(0)?; + } + + Ok(if loc.is_dir() { + OpenResult::Dir(loc) + } else { + // TODO(mivik): is this correct? + let non_cacheable_type = matches!( + loc.metadata()?.node_type, + NodeType::CharacterDevice | NodeType::Fifo | NodeType::Socket + ); + + let direct = non_cacheable_type + || self.path + || self.direct + || loc.flags().contains(NodeFlags::NON_CACHEABLE); + let backend = if !direct || loc.flags().contains(NodeFlags::ALWAYS_CACHE) { + FileBackend::new_cached(loc) + } else { + FileBackend::new_direct(loc) + }; + OpenResult::File(File::new(backend, flags)) + }) + } + + pub fn open_loc(&self, loc: Location) -> VfsResult { + if !self.is_valid() { + return Err(VfsError::InvalidInput); + } + self._open(loc) + } + + pub fn open(&self, context: &FsContext, path: impl AsRef) -> VfsResult { + if !self.is_valid() { + return Err(VfsError::InvalidInput); + } + + let loc = match context.resolve_parent(path.as_ref()) { + Ok((parent, name)) => { + let mut loc = parent.open_file( + &name, + &axfs_ng_vfs::OpenOptions { + create: self.create, + create_new: self.create_new, + node_type: self.node_type, + permission: NodePermission::from_bits_truncate(self.mode as _), + user: self.user, + }, + )?; + if !self.no_follow { + loc = context + .with_current_dir(parent)? + .try_resolve_symlink(loc, &mut 0)?; + } + loc + } + Err(VfsError::InvalidInput) => { + // root directory + context.root_dir().clone() + } + Err(err) => return Err(err), + }; + self._open(loc) + } + + pub(crate) fn to_flags(&self) -> VfsResult { + Ok(match (self.read, self.write, self.append) { + (true, false, false) => FileFlags::READ, + (false, true, false) => FileFlags::WRITE, + (true, true, false) => FileFlags::READ | FileFlags::WRITE, + (false, _, true) => FileFlags::WRITE | FileFlags::APPEND, + (true, _, true) => FileFlags::READ | FileFlags::WRITE | FileFlags::APPEND, + (false, false, false) => return Err(VfsError::InvalidInput), + } | if self.path { + FileFlags::PATH + } else { + FileFlags::empty() + }) + } + + pub(crate) fn is_valid(&self) -> bool { + if !self.read && !self.write && !self.append { + return true; + } + match (self.write, self.append) { + (true, false) => {} + (false, false) => { + if self.truncate || self.create || self.create_new { + return false; + } + } + (_, true) => { + if self.truncate && !self.create_new { + return false; + } + } + } + true + } +} + +impl Default for OpenOptions { + fn default() -> Self { + Self::new() + } +} + +const PAGE_SIZE: usize = 4096; + +#[derive(Debug)] +pub struct PageCache { + addr: VirtAddr, + dirty: bool, +} + +impl PageCache { + fn new() -> VfsResult { + let addr = global_allocator() + .alloc_pages(1, PAGE_SIZE, UsageKind::PageCache) + .inspect_err(|err| { + warn!("Failed to allocate page cache: {:?}", err); + })?; + Ok(Self { + addr: addr.into(), + dirty: false, + }) + } + + pub fn paddr(&self) -> PhysAddr { + virt_to_phys(self.addr) + } + + pub fn mark_dirty(&mut self) { + self.dirty = true; + } + + pub fn data(&mut self) -> &mut [u8] { + unsafe { core::slice::from_raw_parts_mut(self.addr.as_mut_ptr(), PAGE_SIZE) } + } +} + +impl Drop for PageCache { + fn drop(&mut self) { + if self.dirty { + warn!("dirty page dropped without flushing"); + } + global_allocator().dealloc_pages(self.addr.as_usize(), 1, UsageKind::PageCache); + } +} + +struct EvictListener { + listener: Box, + link: LinkedListAtomicLink, +} + +intrusive_adapter!(EvictListenerAdapter = Box: EvictListener { link: LinkedListAtomicLink }); + +struct CachedFileShared { + page_cache: Mutex>, + evict_listeners: Mutex>, +} + +impl CachedFileShared { + pub fn new() -> Self { + Self { + page_cache: Mutex::new(LruCache::new(NonZeroUsize::new(64).unwrap())), + evict_listeners: Mutex::new(LinkedList::default()), + } + } + + pub fn new_unbounded() -> Self { + Self { + page_cache: Mutex::new(LruCache::unbounded()), + evict_listeners: Mutex::new(LinkedList::default()), + } + } +} + +pub struct CachedFile { + inner: Location, + shared: Arc, + in_memory: bool, + /// Only one thread can append to the file at a time, while multiple writers + /// are permitted. + append_lock: RwLock<()>, +} + +impl Clone for CachedFile { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + shared: self.shared.clone(), + in_memory: self.in_memory, + append_lock: RwLock::new(()), + } + } +} + +enum FileUserData { + Weak(Weak), + Strong(Arc), +} + +impl FileUserData { + pub fn get(&self) -> Option> { + match self { + FileUserData::Weak(weak) => weak.upgrade(), + FileUserData::Strong(strong) => Some(strong.clone()), + } + } +} + +impl CachedFile { + pub fn get_or_create(location: Location) -> Self { + let in_memory = location.filesystem().name() == "tmpfs"; + + let mut guard = location.user_data(); + let shared = if let Some(shared) = guard.get::().and_then(|it| it.get()) { + shared + } else { + let (shared, user_data) = if in_memory { + let shared = Arc::new(CachedFileShared::new_unbounded()); + (shared.clone(), FileUserData::Strong(shared)) + } else { + let shared = Arc::new(CachedFileShared::new()); + let user_data = FileUserData::Weak(Arc::downgrade(&shared)); + (shared, user_data) + }; + guard.insert(user_data); + shared + }; + drop(guard); + + Self { + inner: location, + shared, + in_memory, + append_lock: RwLock::new(()), + } + } + + pub fn ptr_eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } + + pub fn in_memory(&self) -> bool { + self.in_memory + } + + pub fn add_evict_listener(&self, listener: F) -> usize + where + F: Fn(u32, &PageCache) + Send + Sync + 'static, + { + let pointer = Box::new(EvictListener { + listener: Box::new(listener), + link: LinkedListAtomicLink::new(), + }); + let handle = pointer.as_ref() as *const EvictListener as usize; + self.shared.evict_listeners.lock().push_back(pointer); + handle + } + + pub unsafe fn remove_evict_listener(&self, handle: usize) { + let mut guard = self.shared.evict_listeners.lock(); + let mut cursor = unsafe { guard.cursor_mut_from_ptr(handle as *const EvictListener) }; + cursor.remove(); + } + + fn evict_cache(&self, file: &FileNode, pn: u32, page: &mut PageCache) -> VfsResult<()> { + for listener in self.shared.evict_listeners.lock().iter() { + (listener.listener)(pn, &page); + } + if page.dirty { + let page_start = pn as u64 * PAGE_SIZE as u64; + let len = (file.len()? - page_start).min(PAGE_SIZE as u64) as usize; + file.write_at(&page.data()[..len], page_start)?; + page.dirty = false; + } + Ok(()) + } + + fn page_or_insert<'a>( + &self, + file: &FileNode, + cache: &'a mut LruCache, + pn: u32, + ) -> VfsResult<(&'a mut PageCache, Option<(u32, PageCache)>)> { + // TODO: Matching the result of `get_mut` confuses compiler. See + // https://users.rust-lang.org/t/return-do-not-release-mutable-borrow/55757. + if cache.contains(&pn) { + return Ok((cache.get_mut(&pn).unwrap(), None)); + } + let mut evicted = None; + if cache.len() == cache.cap().get() { + // Cache is full, remove the least recently used page + if let Some((pn, mut page)) = cache.pop_lru() { + self.evict_cache(file, pn, &mut page)?; + evicted = Some((pn, page)); + } + } + + // Page not in cache, read it + let mut page = PageCache::new()?; + if self.in_memory { + page.data().fill(0); + } else { + file.read_at(page.data(), pn as u64 * PAGE_SIZE as u64)?; + } + cache.put(pn, page); + Ok((cache.get_mut(&pn).unwrap(), evicted)) + } + + pub fn with_page(&self, pn: u32, f: impl FnOnce(Option<&mut PageCache>) -> R) -> R { + f(self.shared.page_cache.lock().get_mut(&pn)) + } + + pub fn with_page_or_insert( + &self, + pn: u32, + f: impl FnOnce(&mut PageCache, Option<(u32, PageCache)>) -> VfsResult, + ) -> VfsResult { + let mut guard = self.shared.page_cache.lock(); + let (page, evicted) = self.page_or_insert(self.inner.entry().as_file()?, &mut guard, pn)?; + f(page, evicted) + } + + fn with_pages( + &self, + range: Range, + page_initial: impl FnOnce(&FileNode) -> VfsResult, + mut page_each: impl FnMut(T, &mut PageCache, Range) -> VfsResult, + ) -> VfsResult { + let file = self.inner.entry().as_file()?; + let mut initial = page_initial(file)?; + let start_page = (range.start / PAGE_SIZE as u64) as u32; + let end_page = range.end.div_ceil(PAGE_SIZE as u64) as u32; + let mut page_offset = (range.start % PAGE_SIZE as u64) as usize; + for pn in start_page..end_page { + let page_start = pn as u64 * PAGE_SIZE as u64; + + let mut guard = self.shared.page_cache.lock(); + let page = self.page_or_insert(file, &mut guard, pn)?.0; + + initial = page_each( + initial, + page, + page_offset..(range.end - page_start).min(PAGE_SIZE as u64) as usize, + )?; + page_offset = 0; + } + + Ok(initial) + } + + pub fn read_at(&self, dst: &mut impl BufMut, offset: u64) -> VfsResult { + let len = self.inner.len()?; + let end = (offset + dst.remaining_mut() as u64).min(len); + if end <= offset { + return Ok(0); + } + self.with_pages( + offset..end, + |_| Ok(0), + |read, page, range| { + let len = range.end - range.start; + dst.write(&page.data()[range.start..range.end])?; + Ok(read + len) + }, + ) + } + + fn write_at_locked(&self, buf: &mut impl Buf, offset: u64) -> VfsResult { + let end = offset + buf.remaining() as u64; + self.with_pages( + offset..end, + |file| { + if end > file.len()? { + file.set_len(end)?; + } + Ok(0) + }, + |written, page, range| { + let len = range.end - range.start; + buf.read(&mut page.data()[range.start..range.end])?; + if !self.in_memory { + page.dirty = true; + } + Ok(written + len) + }, + ) + } + + pub fn write_at(&self, buf: &mut impl Buf, offset: u64) -> VfsResult { + let _guard = self.append_lock.read(); + self.write_at_locked(buf, offset) + } + + pub fn append(&self, buf: &mut impl Buf) -> VfsResult<(usize, u64)> { + let _guard = self.append_lock.write(); + let file = self.inner.entry().as_file()?; + let len = file.len()?; + self.write_at_locked(buf, len) + .map(|written| (written, len + written as u64)) + } + + pub fn set_len(&self, len: u64) -> VfsResult<()> { + let file = self.inner.entry().as_file()?; + let old_len = file.len()?; + file.set_len(len)?; + + let old_last_page = (old_len / PAGE_SIZE as u64) as u32; + let new_last_page = (len / PAGE_SIZE as u64) as u32; + if old_len < len { + let mut guard = self.shared.page_cache.lock(); + if let Some(page) = guard.get_mut(&old_last_page) { + let page_start = old_last_page as u64 * PAGE_SIZE as u64; + let old_page_offset = (old_len - page_start) as usize; + let new_page_offset = (len - page_start).min(PAGE_SIZE as u64) as usize; + page.data()[old_page_offset..new_page_offset].fill(0); + } + } else if old_last_page > new_last_page { + // For truncating, we need to remove all pages that are beyond the + // new length + // TODO(mivik): can this be more efficient? + let mut guard = self.shared.page_cache.lock(); + let keys = guard + .iter() + .map(|(k, _)| *k) + .filter(|it| *it > new_last_page) + .collect::>(); + for pn in keys { + if let Some(mut page) = guard.pop(&pn) { + if !self.in_memory { + // Don't write back pages since they're discarded + page.dirty = false; + self.evict_cache(file, pn, &mut page)?; + } + } + } + } + Ok(()) + } + + pub fn sync(&self, data_only: bool) -> VfsResult<()> { + if self.in_memory { + return Ok(()); + } + let file = self.inner.entry().as_file()?; + let mut guard = self.shared.page_cache.lock(); + while let Some((pn, mut page)) = guard.pop_lru() { + self.evict_cache(file, pn, &mut page)?; + } + file.sync(data_only)?; + Ok(()) + } + + pub fn location(&self) -> &Location { + &self.inner + } +} + +impl Drop for CachedFile { + fn drop(&mut self) { + if Arc::strong_count(&self.shared) > 1 { + // If there are other references to this cached file, we don't + // need to drop it. + return; + } + if let Err(err) = self.sync(false) { + warn!("Failed to sync file on drop: {err:?}"); + } + } +} + +/// Low-level interface for file operations. +#[derive(Clone)] +pub enum FileBackend { + Cached(CachedFile), + Direct(Location), +} + +impl FileBackend { + pub(crate) fn new_direct(location: Location) -> Self { + Self::Direct(location) + } + + pub(crate) fn new_cached(location: Location) -> Self { + Self::Cached(CachedFile::get_or_create(location)) + } + + pub fn read_at(&self, dst: &mut impl BufMut, mut offset: u64) -> VfsResult { + match self { + Self::Cached(cached) => cached.read_at(dst, offset), + Self::Direct(loc) => dst.fill(|buf| { + loc.entry().as_file()?.read_at(buf, offset).inspect(|read| { + offset += *read as u64; + }) + }), + } + } + + pub fn write_at(&self, src: &mut impl Buf, mut offset: u64) -> VfsResult { + match self { + Self::Cached(cached) => cached.write_at(src, offset), + Self::Direct(loc) => src.consume(|buf| { + loc.entry() + .as_file()? + .write_at(buf, offset) + .inspect(|written| { + offset += *written as u64; + }) + }), + } + } + + pub fn append(&self, src: &mut impl Buf) -> VfsResult<(usize, u64)> { + match self { + Self::Cached(cached) => cached.append(src), + Self::Direct(loc) => { + let mut buffer = Box::<[u8]>::new_uninit_slice(src.remaining()); + src.read(unsafe { buffer.assume_init_mut() })?; + loc.entry() + .as_file()? + .append(unsafe { buffer.assume_init_ref() }) + } + } + } + + pub fn location(&self) -> &Location { + match self { + Self::Cached(cached) => cached.location(), + Self::Direct(loc) => loc, + } + } + + pub fn sync(&self, data_only: bool) -> VfsResult<()> { + match self { + Self::Cached(cached) => cached.sync(data_only), + Self::Direct(loc) => loc.entry().as_file()?.sync(data_only), + } + } + + pub fn set_len(&self, len: u64) -> VfsResult<()> { + match self { + Self::Cached(cached) => cached.set_len(len), + Self::Direct(loc) => loc.entry().as_file()?.set_len(len), + } + } +} + +/// Provides `std::fs::File`-like interface. +pub struct File { + inner: FileBackend, + flags: FileFlags, + position: Option>, + #[cfg(feature = "times")] + access_flags: AtomicU8, +} + +impl File { + pub fn new(inner: FileBackend, flags: FileFlags) -> Self { + let position = if inner.location().flags().contains(NodeFlags::STREAM) { + None + } else { + Some(Mutex::new(if flags.contains(FileFlags::APPEND) { + inner.location().len().unwrap_or_default() + } else { + 0 + })) + }; + Self { + inner, + flags, + position, + #[cfg(feature = "times")] + access_flags: AtomicU8::new(0), + } + } + + pub fn open(context: &FsContext, path: impl AsRef) -> VfsResult { + OpenOptions::new() + .read(true) + .open(context, path.as_ref()) + .and_then(OpenResult::into_file) + } + + pub fn create(context: &FsContext, path: impl AsRef) -> VfsResult { + OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(context, path.as_ref()) + .and_then(OpenResult::into_file) + } + + pub fn access(&self, flags: FileFlags) -> VfsResult<&FileBackend> { + if self.flags.contains(flags) && !self.is_path() { + Ok(&self.inner) + } else { + Err(VfsError::BadFileDescriptor) + } + } + + pub fn is_path(&self) -> bool { + self.flags.contains(FileFlags::PATH) + } + + pub fn flags(&self) -> FileFlags { + self.flags + } + + pub fn backend(&self) -> VfsResult<&FileBackend> { + self.access(FileFlags::empty())?; + Ok(&self.inner) + } + + pub fn location(&self) -> &Location { + self.inner.location() + } + + /// Reads a number of bytes starting from a given offset. + pub fn read_at(&self, dst: &mut impl BufMut, offset: u64) -> VfsResult { + self.access(FileFlags::READ)?.read_at(dst, offset) + } + + /// Writes a number of bytes starting from a given offset. + pub fn write_at(&self, src: &mut impl Buf, offset: u64) -> VfsResult { + self.access(FileFlags::WRITE)?.write_at(src, offset) + } + + /// Attempts to sync OS-internal file content and metadata to disk. + /// + /// If `data_only` is `true`, only the file data is synced, not the + /// metadata. + pub fn sync(&self, data_only: bool) -> VfsResult<()> { + self.access(FileFlags::empty())?; + self.inner.sync(data_only) + } + + pub fn read(&self, dst: &mut impl BufMut) -> axio::Result { + #[cfg(feature = "times")] + { + self.access_flags.fetch_or(1, Ordering::AcqRel); + } + if let Some(pos) = self.position.as_ref() { + let mut pos = pos.lock(); + self.read_at(dst, *pos).inspect(|n| { + *pos += *n as u64; + }) + } else { + self.read_at(dst, 0) + } + } + + pub fn write(&self, src: &mut impl Buf) -> axio::Result { + #[cfg(feature = "times")] + { + self.access_flags.fetch_or(3, Ordering::AcqRel); + } + if let Some(pos) = self.position.as_ref() { + let mut pos = pos.lock(); + if let Ok(f) = self.access(FileFlags::APPEND) { + f.append(src).map(|(written, new_size)| { + *pos = new_size; + written + }) + } else { + self.write_at(src, *pos).inspect(|n| { + *pos += *n as u64; + }) + } + } else { + self.write_at(src, 0) + } + } + + pub fn flush(&self) -> axio::Result { + self.access(FileFlags::empty())?; + Ok(()) + } +} + +impl<'a> axio::Read for &'a File { + fn read(&mut self, mut buf: &mut [u8]) -> axio::Result { + (*self).read(&mut buf) + } +} + +impl<'a> axio::Write for &'a File { + fn write(&mut self, mut buf: &[u8]) -> axio::Result { + (*self).write(&mut buf) + } + + fn flush(&mut self) -> axio::Result { + (*self).flush() + } +} + +impl<'a> axio::Seek for &'a File { + fn seek(&mut self, pos: SeekFrom) -> axio::Result { + self.access(FileFlags::empty())?; + + if let Some(guard) = self.position.as_ref() { + let mut guard = guard.lock(); + let new_pos = match pos { + SeekFrom::Start(pos) => pos, + SeekFrom::End(off) => { + let size = self.access(FileFlags::empty())?.location().len()?; + size.checked_add_signed(off).ok_or(VfsError::InvalidInput)? + } + SeekFrom::Current(off) => guard + .checked_add_signed(off) + .ok_or(VfsError::InvalidInput)?, + }; + *guard = new_pos; + Ok(new_pos) + } else { + Ok(0) + } + } +} + +impl Pollable for File { + fn poll(&self) -> IoEvents { + self.inner.location().poll() + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + self.inner.location().register(context, events) + } +} + +#[cfg(feature = "times")] +impl Drop for File { + fn drop(&mut self) { + let flags = self.access_flags.load(Ordering::Acquire); + if flags != 0 { + let mut update = axfs_ng_vfs::MetadataUpdate::default(); + if flags & 1 != 0 { + update.atime = Some(axhal::time::wall_time()); + } + if flags & 2 != 0 { + update.mtime = Some(axhal::time::wall_time()); + } + if let Err(err) = self.inner.location().update_metadata(update) { + warn!("Failed to update file times on drop: {err:?}"); + } + } + } +} diff --git a/modules/axfs/src/highlevel/fs.rs b/modules/axfs/src/highlevel/fs.rs new file mode 100644 index 0000000000..29aef68887 --- /dev/null +++ b/modules/axfs/src/highlevel/fs.rs @@ -0,0 +1,345 @@ +use alloc::{ + borrow::{Cow, ToOwned}, + collections::vec_deque::VecDeque, + string::String, + sync::Arc, + vec::Vec, +}; + +use axfs_ng_vfs::{ + Location, Metadata, NodePermission, NodeType, VfsError, VfsResult, + path::{Component, Components, Path, PathBuf}, +}; +use axio::{Read, Write}; +use axsync::Mutex; +use spin::Once; + +use super::File; + +pub const SYMLINKS_MAX: usize = 40; + +pub static ROOT_FS_CONTEXT: Once = Once::new(); + +scope_local::scope_local! { + pub static FS_CONTEXT: Arc> = + Arc::new(Mutex::new( + ROOT_FS_CONTEXT + .get() + .expect("Root FS context not initialized") + .clone(), + )); +} + +pub struct ReadDirEntry { + pub name: String, + pub ino: u64, + pub node_type: NodeType, + pub offset: u64, +} + +/// Provides `std::fs`-like interface. +#[derive(Debug, Clone)] +pub struct FsContext { + root_dir: Location, + current_dir: Location, +} + +impl FsContext { + pub fn new(root_dir: Location) -> Self { + Self { + root_dir: root_dir.clone(), + current_dir: root_dir, + } + } + + pub fn root_dir(&self) -> &Location { + &self.root_dir + } + + pub fn current_dir(&self) -> &Location { + &self.current_dir + } + + pub fn set_current_dir(&mut self, current_dir: Location) -> VfsResult<()> { + current_dir.check_is_dir()?; + self.current_dir = current_dir; + Ok(()) + } + + pub fn with_current_dir(&self, current_dir: Location) -> VfsResult { + current_dir.check_is_dir()?; + Ok(Self { + root_dir: self.root_dir.clone(), + current_dir, + }) + } + + /// Attempts to resolve a possible symlink, at the current location (this + /// assumes that `loc` is a child of current directory). + pub fn try_resolve_symlink( + &self, + loc: Location, + follow_count: &mut usize, + ) -> VfsResult { + if loc.node_type() != NodeType::Symlink { + return Ok(loc); + } + if *follow_count >= SYMLINKS_MAX { + return Err(VfsError::FilesystemLoop); + } + *follow_count += 1; + let target = loc.read_link()?; + if target.is_empty() { + return Err(VfsError::NotFound); + } + self.resolve_components(PathBuf::from(target).components(), follow_count) + } + + fn lookup(&self, dir: &Location, name: &str, follow_count: &mut usize) -> VfsResult { + let loc = dir.lookup_no_follow(name)?; + self.with_current_dir(dir.clone())? + .try_resolve_symlink(loc, follow_count) + } + + fn resolve_components( + &self, + components: Components, + follow_count: &mut usize, + ) -> VfsResult { + let mut dir = self.current_dir.clone(); + for comp in components { + match comp { + Component::CurDir => {} + Component::ParentDir => { + dir = dir.parent().unwrap_or_else(|| self.root_dir.clone()); + } + Component::RootDir => { + dir = self.root_dir.clone(); + } + Component::Normal(name) => { + dir = self.lookup(&dir, name, follow_count)?; + } + } + } + Ok(dir) + } + + fn resolve_inner<'a>( + &self, + path: &'a Path, + follow_count: &mut usize, + ) -> VfsResult<(Location, Option<&'a str>)> { + let entry_name = path.file_name(); + let mut components = path.components(); + if entry_name.is_some() { + components.next_back(); + } + let dir = self.resolve_components(components, follow_count)?; + dir.check_is_dir()?; + Ok((dir, entry_name)) + } + + /// Resolves a path starting from `current_dir`. + pub fn resolve(&self, path: impl AsRef) -> VfsResult { + let mut follow_count = 0; + let (dir, name) = self.resolve_inner(path.as_ref(), &mut follow_count)?; + match name { + Some(name) => self.lookup(&dir, name, &mut follow_count), + None => Ok(dir), + } + } + + /// Resolves a path starting from `current_dir` not following symlinks. + pub fn resolve_no_follow(&self, path: impl AsRef) -> VfsResult { + let (dir, name) = self.resolve_inner(path.as_ref(), &mut 0)?; + match name { + Some(name) => dir.lookup_no_follow(name), + None => Ok(dir), + } + } + + /// Taking current node as root directory, resolves a path starting from + /// `current_dir`. + /// + /// Returns `(parent_dir, entry_name)`, where `entry_name` is the name of + /// the entry. + pub fn resolve_parent<'a>(&self, path: &'a Path) -> VfsResult<(Location, Cow<'a, str>)> { + let (dir, name) = self.resolve_inner(path, &mut 0)?; + if let Some(name) = name { + Ok((dir, Cow::Borrowed(name))) + } else if let Some(parent) = dir.parent() { + Ok((parent, Cow::Owned(dir.name().to_owned()))) + } else { + Err(VfsError::InvalidInput) + } + } + + /// Resolves a path starting from `current_dir`, returning the parent + /// directory and the name of the entry. + /// + /// This function requires that the entry does not exist and the parent + /// exists. Note that, it does not perform an actual check to ensure the + /// entry's non-existence. It simply raises an error if the entry name is + /// not present in the path. + pub fn resolve_nonexistent<'a>(&self, path: &'a Path) -> VfsResult<(Location, &'a str)> { + let (dir, name) = self.resolve_inner(path, &mut 0)?; + if let Some(name) = name { + Ok((dir, name)) + } else { + Err(VfsError::InvalidInput) + } + } + + /// Retrieves metadata for the file. + pub fn metadata(&self, path: impl AsRef) -> VfsResult { + self.resolve(path)?.metadata() + } + + /// Reads the entire contents of a file into a bytes vector. + pub fn read(&self, path: impl AsRef) -> VfsResult> { + let mut buf = Vec::new(); + let file = File::open(self, path.as_ref())?; + (&file).read_to_end(&mut buf)?; + Ok(buf) + } + + /// Reads the entire contents of a file into a string. + pub fn read_to_string(&self, path: impl AsRef) -> VfsResult { + String::from_utf8(self.read(path)?).map_err(|_| VfsError::InvalidData) + } + + /// Writes a slice as the entire contents of a file. + /// + /// This function will create a file if it does not exist, and will entirely + /// replace its contents if it does. + pub fn write(&self, path: impl AsRef, buf: impl AsRef<[u8]>) -> VfsResult<()> { + let file = File::create(self, path.as_ref())?; + (&file).write_all(buf.as_ref())?; + Ok(()) + } + + /// Returns an iterator over the entries in a directory. + pub fn read_dir(&self, path: impl AsRef) -> VfsResult { + let dir = self.resolve(path)?; + Ok(ReadDir { + dir, + buf: VecDeque::new(), + offset: 0, + ended: false, + }) + } + + /// Removes a file from the filesystem. + pub fn remove_file(&self, path: impl AsRef) -> VfsResult<()> { + let entry = self.resolve_no_follow(path.as_ref())?; + entry + .parent() + .ok_or(VfsError::IsADirectory)? + .unlink(entry.name(), false) + } + + /// Removes a directory from the filesystem. + pub fn remove_dir(&self, path: impl AsRef) -> VfsResult<()> { + let entry = self.resolve_no_follow(path.as_ref())?; + entry + .parent() + .ok_or(VfsError::ResourceBusy)? + .unlink(entry.name(), true) + } + + /// Renames a file or directory to a new name, replacing the original file + /// if `to` already exists. + pub fn rename(&self, from: impl AsRef, to: impl AsRef) -> VfsResult<()> { + let (src_dir, src_name) = self.resolve_parent(from.as_ref())?; + let (dst_dir, dst_name) = self.resolve_parent(to.as_ref())?; + src_dir.rename(&src_name, &dst_dir, &dst_name) + } + + /// Creates a new, empty directory at the provided path. + pub fn create_dir(&self, path: impl AsRef, mode: NodePermission) -> VfsResult { + let (dir, name) = self.resolve_nonexistent(path.as_ref())?; + dir.create(name, NodeType::Directory, mode) + } + + /// Creates a new hard link on the filesystem. + pub fn link( + &self, + old_path: impl AsRef, + new_path: impl AsRef, + ) -> VfsResult { + let old = self.resolve(old_path.as_ref())?; + let (new_dir, new_name) = self.resolve_nonexistent(new_path.as_ref())?; + new_dir.link(new_name, &old) + } + + /// Creates a new symbolic link on the filesystem. + pub fn symlink( + &self, + target: impl AsRef, + link_path: impl AsRef, + ) -> VfsResult { + let (dir, name) = self.resolve_nonexistent(link_path.as_ref())?; + if dir.lookup_no_follow(name).is_ok() { + return Err(VfsError::AlreadyExists); + } + let symlink = dir.create(name, NodeType::Symlink, NodePermission::default())?; + symlink.entry().as_file()?.set_symlink(target.as_ref())?; + Ok(symlink) + } + + /// Returns the canonical, absolute form of a path. + pub fn canonicalize(&self, path: impl AsRef) -> VfsResult { + self.resolve(path.as_ref())?.absolute_path() + } +} + +/// Iterator returned by [`FsContext::read_dir`]. +pub struct ReadDir { + dir: Location, + buf: VecDeque, + offset: u64, + ended: bool, +} + +impl ReadDir { + // TODO: tune this + pub const BUF_SIZE: usize = 128; +} + +impl Iterator for ReadDir { + type Item = VfsResult; + + fn next(&mut self) -> Option { + if self.ended { + return None; + } + + if self.buf.is_empty() { + self.buf.clear(); + let result = self.dir.read_dir( + self.offset, + &mut |name: &str, ino: u64, node_type: NodeType, offset: u64| { + self.buf.push_back(ReadDirEntry { + name: name.to_owned(), + ino, + node_type, + offset, + }); + self.offset = offset; + self.buf.len() < Self::BUF_SIZE + }, + ); + + // We handle errors only if we didn't get any entries + if self.buf.is_empty() { + if let Err(err) = result { + return Some(Err(err)); + } + self.ended = true; + return None; + } + } + + self.buf.pop_front().map(Ok) + } +} diff --git a/modules/axfs/src/highlevel/mod.rs b/modules/axfs/src/highlevel/mod.rs new file mode 100644 index 0000000000..e2e6ac0a76 --- /dev/null +++ b/modules/axfs/src/highlevel/mod.rs @@ -0,0 +1,5 @@ +mod file; +mod fs; + +pub use file::*; +pub use fs::*; diff --git a/modules/axfs/src/lib.rs b/modules/axfs/src/lib.rs index 4c69f5c745..addfa9c344 100644 --- a/modules/axfs/src/lib.rs +++ b/modules/axfs/src/lib.rs @@ -1,46 +1,30 @@ -//! [ArceOS](https://github.com/arceos-org/arceos) filesystem module. -//! -//! It provides unified filesystem operations for various filesystems. -//! -//! # Cargo Features -//! -//! - `fatfs`: Use [FAT] as the main filesystem and mount it on `/`. This feature -//! is **enabled** by default. -//! - `devfs`: Mount [`axfs_devfs::DeviceFileSystem`] on `/dev`. This feature is -//! **enabled** by default. -//! - `ramfs`: Mount [`axfs_ramfs::RamFileSystem`] on `/tmp`. This feature is -//! **enabled** by default. -//! - `myfs`: Allow users to define their custom filesystems to override the -//! default. In this case, [`MyFileSystemIf`] is required to be implemented -//! to create and initialize other filesystems. This feature is **disabled** by -//! by default, but it will override other filesystem selection features if -//! both are enabled. -//! -//! [FAT]: https://en.wikipedia.org/wiki/File_Allocation_Table -//! [`MyFileSystemIf`]: fops::MyFileSystemIf - -#![cfg_attr(all(not(test), not(doc)), no_std)] -#![feature(doc_auto_cfg)] +#![no_std] +#![allow(clippy::new_ret_no_self)] +#![feature(maybe_uninit_slice)] + +extern crate alloc; #[macro_use] extern crate log; -extern crate alloc; -mod dev; -mod fs; -mod mounts; -mod root; +use axdriver::{AxBlockDevice, AxDeviceContainer, prelude::*}; -pub mod api; -pub mod fops; +#[cfg(feature = "fat")] +mod disk; +mod fs; -use axdriver::{AxDeviceContainer, prelude::*}; +mod highlevel; +pub use highlevel::*; -/// Initializes filesystems by block devices. -pub fn init_filesystems(mut blk_devs: AxDeviceContainer) { - info!("Initialize filesystems..."); +pub fn init_filesystems(mut block_devs: AxDeviceContainer) { + info!("Initialize filesystem subsystem..."); - let dev = blk_devs.take_one().expect("No block device found!"); + let dev = block_devs.take_one().expect("No block device found!"); info!(" use block device 0: {:?}", dev.device_name()); - self::root::init_rootfs(self::dev::Disk::new(dev)); + + let fs = fs::new_default(dev).expect("Failed to initialize filesystem"); + info!(" filesystem type: {:?}", fs.name()); + + let mp = axfs_ng_vfs::Mountpoint::new_root(&fs); + ROOT_FS_CONTEXT.call_once(|| FsContext::new(mp.root_location())); } diff --git a/modules/axfs/src/mounts.rs b/modules/axfs/src/mounts.rs deleted file mode 100644 index aa9f434d5d..0000000000 --- a/modules/axfs/src/mounts.rs +++ /dev/null @@ -1,82 +0,0 @@ -use alloc::sync::Arc; -use axfs_vfs::{VfsNodeType, VfsOps, VfsResult}; - -use crate::fs; - -#[cfg(feature = "devfs")] -pub(crate) fn devfs() -> Arc { - let null = fs::devfs::NullDev; - let zero = fs::devfs::ZeroDev; - let urandom = fs::devfs::UrandomDev::default(); - let bar = fs::devfs::ZeroDev; - let devfs = fs::devfs::DeviceFileSystem::new(); - let foo_dir = devfs.mkdir("foo"); - devfs.add("null", Arc::new(null)); - devfs.add("zero", Arc::new(zero)); - devfs.add("urandom", Arc::new(urandom)); - foo_dir.add("bar", Arc::new(bar)); - Arc::new(devfs) -} - -#[cfg(feature = "ramfs")] -pub(crate) fn ramfs() -> Arc { - Arc::new(fs::ramfs::RamFileSystem::new()) -} - -#[cfg(feature = "procfs")] -pub(crate) fn procfs() -> VfsResult> { - let procfs = fs::ramfs::RamFileSystem::new(); - let proc_root = procfs.root_dir(); - - // Create /proc/sys/net/core/somaxconn - proc_root.create("sys", VfsNodeType::Dir)?; - proc_root.create("sys/net", VfsNodeType::Dir)?; - proc_root.create("sys/net/core", VfsNodeType::Dir)?; - proc_root.create("sys/net/core/somaxconn", VfsNodeType::File)?; - let file_somaxconn = proc_root.clone().lookup("./sys/net/core/somaxconn")?; - file_somaxconn.write_at(0, b"4096\n")?; - - // Create /proc/sys/vm/overcommit_memory - proc_root.create("sys/vm", VfsNodeType::Dir)?; - proc_root.create("sys/vm/overcommit_memory", VfsNodeType::File)?; - let file_over = proc_root.clone().lookup("./sys/vm/overcommit_memory")?; - file_over.write_at(0, b"0\n")?; - - // Create /proc/self/stat - proc_root.create("self", VfsNodeType::Dir)?; - proc_root.create("self/stat", VfsNodeType::File)?; - - Ok(Arc::new(procfs)) -} - -#[cfg(feature = "sysfs")] -pub(crate) fn sysfs() -> VfsResult> { - let sysfs = fs::ramfs::RamFileSystem::new(); - let sys_root = sysfs.root_dir(); - - // Create /sys/kernel/mm/transparent_hugepage/enabled - sys_root.create("kernel", VfsNodeType::Dir)?; - sys_root.create("kernel/mm", VfsNodeType::Dir)?; - sys_root.create("kernel/mm/transparent_hugepage", VfsNodeType::Dir)?; - sys_root.create("kernel/mm/transparent_hugepage/enabled", VfsNodeType::File)?; - let file_hp = sys_root - .clone() - .lookup("./kernel/mm/transparent_hugepage/enabled")?; - file_hp.write_at(0, b"always [madvise] never\n")?; - - // Create /sys/devices/system/clocksource/clocksource0/current_clocksource - sys_root.create("devices", VfsNodeType::Dir)?; - sys_root.create("devices/system", VfsNodeType::Dir)?; - sys_root.create("devices/system/clocksource", VfsNodeType::Dir)?; - sys_root.create("devices/system/clocksource/clocksource0", VfsNodeType::Dir)?; - sys_root.create( - "devices/system/clocksource/clocksource0/current_clocksource", - VfsNodeType::File, - )?; - let file_cc = sys_root - .clone() - .lookup("devices/system/clocksource/clocksource0/current_clocksource")?; - file_cc.write_at(0, b"tsc\n")?; - - Ok(Arc::new(sysfs)) -} diff --git a/modules/axfs/src/root.rs b/modules/axfs/src/root.rs deleted file mode 100644 index ac1016a4b5..0000000000 --- a/modules/axfs/src/root.rs +++ /dev/null @@ -1,331 +0,0 @@ -//! Root directory of the filesystem -//! -//! TODO: it doesn't work very well if the mount points have containment relationships. - -use alloc::{string::String, sync::Arc, vec::Vec}; -use axerrno::{AxError, AxResult, ax_err}; -use axfs_vfs::{VfsNodeAttr, VfsNodeOps, VfsNodeRef, VfsNodeType, VfsOps, VfsResult}; -use axsync::Mutex; -use lazyinit::LazyInit; -use scope_local::scope_local; - -use crate::{api::FileType, fs, mounts}; - -struct MountPoint { - path: &'static str, - fs: Arc, -} - -impl MountPoint { - pub fn new(path: &'static str, fs: Arc) -> Self { - Self { path, fs } - } -} - -impl Drop for MountPoint { - fn drop(&mut self) { - self.fs.umount().ok(); - } -} - -struct RootDirectory { - main_fs: Arc, - mounts: Vec, -} - -static ROOT_DIR: LazyInit> = LazyInit::new(); - -impl RootDirectory { - pub const fn new(main_fs: Arc) -> Self { - Self { - main_fs, - mounts: Vec::new(), - } - } - - pub fn mount(&mut self, path: &'static str, fs: Arc) -> AxResult { - if path == "/" { - return ax_err!(InvalidInput, "cannot mount root filesystem"); - } - if !path.starts_with('/') { - return ax_err!(InvalidInput, "mount path must start with '/'"); - } - if self.mounts.iter().any(|mp| mp.path == path) { - return ax_err!(InvalidInput, "mount point already exists"); - } - // create the mount point in the main filesystem if it does not exist - self.main_fs.root_dir().create(path, FileType::Dir)?; - fs.mount(path, self.main_fs.root_dir().lookup(path)?)?; - self.mounts.push(MountPoint::new(path, fs)); - Ok(()) - } - - pub fn _umount(&mut self, path: &str) { - self.mounts.retain(|mp| mp.path != path); - } - - pub fn contains(&self, path: &str) -> bool { - self.mounts.iter().any(|mp| mp.path == path) - } - - fn lookup_mounted_fs(&self, path: &str, f: F) -> AxResult - where - F: FnOnce(Arc, &str) -> AxResult, - { - debug!("lookup at root: {path}"); - let path = path.trim_matches('/'); - if let Some(rest) = path.strip_prefix("./") { - return self.lookup_mounted_fs(rest, f); - } - - let mut idx = 0; - let mut max_len = 0; - - // Find the filesystem that has the longest mounted path match - // TODO: more efficient, e.g. trie - for (i, mp) in self.mounts.iter().enumerate() { - // skip the first '/' - if path.starts_with(&mp.path[1..]) && mp.path.len() - 1 > max_len { - max_len = mp.path.len() - 1; - idx = i; - } - } - - if max_len == 0 { - f(self.main_fs.clone(), path) // not matched any mount point - } else { - f(self.mounts[idx].fs.clone(), &path[max_len..]) // matched at `idx` - } - } -} - -impl VfsNodeOps for RootDirectory { - axfs_vfs::impl_vfs_dir_default! {} - - fn get_attr(&self) -> VfsResult { - self.main_fs.root_dir().get_attr() - } - - fn lookup(self: Arc, path: &str) -> VfsResult { - self.lookup_mounted_fs(path, |fs, rest_path| fs.root_dir().lookup(rest_path)) - } - - fn create(&self, path: &str, ty: VfsNodeType) -> VfsResult { - self.lookup_mounted_fs(path, |fs, rest_path| { - if rest_path.is_empty() { - Ok(()) // already exists - } else { - fs.root_dir().create(rest_path, ty) - } - }) - } - - fn remove(&self, path: &str) -> VfsResult { - self.lookup_mounted_fs(path, |fs, rest_path| { - if rest_path.is_empty() { - ax_err!(PermissionDenied) // cannot remove mount points - } else { - fs.root_dir().remove(rest_path) - } - }) - } - - fn rename(&self, src_path: &str, dst_path: &str) -> VfsResult { - self.lookup_mounted_fs(src_path, |fs, rest_path| { - if rest_path.is_empty() { - ax_err!(PermissionDenied) // cannot rename mount points - } else { - fs.root_dir().rename(rest_path, dst_path) - } - }) - } -} - -#[derive(Clone)] -struct CurrentDir { - path: String, - node: VfsNodeRef, -} - -impl Default for CurrentDir { - fn default() -> Self { - Self { - path: String::from("/"), - node: ROOT_DIR.clone(), - } - } -} - -scope_local! { - static CURRENT_DIR: Mutex = Mutex::new(CurrentDir::default()); -} - -pub(crate) fn init_rootfs(disk: crate::dev::Disk) { - cfg_if::cfg_if! { - if #[cfg(feature = "myfs")] { // override the default filesystem - let main_fs = fs::myfs::new_myfs(disk); - } else if #[cfg(feature = "ext4fs")] { - static EXT4_FS: LazyInit> = LazyInit::new(); - EXT4_FS.init_once(Arc::new(fs::ext4fs::Ext4FileSystem::new(disk))); - let main_fs = EXT4_FS.clone(); - } else if #[cfg(feature = "fatfs")] { - static FAT_FS: LazyInit> = LazyInit::new(); - FAT_FS.init_once(Arc::new(fs::fatfs::FatFileSystem::new(disk))); - FAT_FS.init(); - let main_fs = FAT_FS.clone(); - } - } - - let mut root_dir = RootDirectory::new(main_fs); - - #[cfg(feature = "devfs")] - root_dir - .mount("/dev", mounts::devfs()) - .expect("failed to mount devfs at /dev"); - - #[cfg(feature = "ramfs")] - root_dir - .mount("/tmp", mounts::ramfs()) - .expect("failed to mount ramfs at /tmp"); - - // Mount another ramfs as procfs - #[cfg(feature = "procfs")] - root_dir // should not fail - .mount("/proc", mounts::procfs().unwrap()) - .expect("fail to mount procfs at /proc"); - - // Mount another ramfs as sysfs - #[cfg(feature = "sysfs")] - root_dir // should not fail - .mount("/sys", mounts::sysfs().unwrap()) - .expect("fail to mount sysfs at /sys"); - - ROOT_DIR.init_once(Arc::new(root_dir)); -} - -fn parent_node_of(dir: Option<&VfsNodeRef>, path: &str) -> VfsNodeRef { - if path.starts_with('/') { - ROOT_DIR.clone() - } else { - dir.cloned() - .unwrap_or_else(|| CURRENT_DIR.lock().node.clone()) - } -} - -pub(crate) fn absolute_path(path: &str) -> AxResult { - if path.starts_with('/') { - Ok(axfs_vfs::path::canonicalize(path)) - } else { - let path = CURRENT_DIR.lock().path.clone() + path; - Ok(axfs_vfs::path::canonicalize(&path)) - } -} - -pub(crate) fn lookup(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { - if path.is_empty() { - return ax_err!(NotFound); - } - let node = parent_node_of(dir, path).lookup(path)?; - if path.ends_with('/') && !node.get_attr()?.is_dir() { - ax_err!(NotADirectory) - } else { - Ok(node) - } -} - -pub(crate) fn create_file(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { - if path.is_empty() { - return ax_err!(NotFound); - } else if path.ends_with('/') { - return ax_err!(NotADirectory); - } - let parent = parent_node_of(dir, path); - parent.create(path, VfsNodeType::File)?; - parent.lookup(path) -} - -pub(crate) fn create_dir(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { - match lookup(dir, path) { - Ok(_) => ax_err!(AlreadyExists), - Err(AxError::NotFound) => parent_node_of(dir, path).create(path, VfsNodeType::Dir), - Err(e) => Err(e), - } -} - -pub(crate) fn remove_file(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { - let node = lookup(dir, path)?; - let attr = node.get_attr()?; - if attr.is_dir() { - ax_err!(IsADirectory) - } else if !attr.perm().owner_writable() { - ax_err!(PermissionDenied) - } else { - parent_node_of(dir, path).remove(path) - } -} - -pub(crate) fn remove_dir(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { - if path.is_empty() { - return ax_err!(NotFound); - } - let path_check = path.trim_matches('/'); - if path_check.is_empty() { - return ax_err!(DirectoryNotEmpty); // rm -d '/' - } else if path_check == "." - || path_check == ".." - || path_check.ends_with("/.") - || path_check.ends_with("/..") - { - return ax_err!(InvalidInput); - } - if ROOT_DIR.contains(&absolute_path(path)?) { - return ax_err!(PermissionDenied); - } - - let node = lookup(dir, path)?; - let attr = node.get_attr()?; - if !attr.is_dir() { - ax_err!(NotADirectory) - } else if !attr.perm().owner_writable() { - ax_err!(PermissionDenied) - } else { - parent_node_of(dir, path).remove(path) - } -} - -pub(crate) fn current_dir() -> AxResult { - Ok(CURRENT_DIR.lock().path.clone()) -} - -pub(crate) fn set_current_dir(path: &str) -> AxResult { - let mut abs_path = absolute_path(path)?; - if !abs_path.ends_with('/') { - abs_path += "/"; - } - if abs_path == "/" { - *CURRENT_DIR.lock() = CurrentDir::default(); - return Ok(()); - } - - let node = lookup(None, &abs_path)?; - let attr = node.get_attr()?; - if !attr.is_dir() { - ax_err!(NotADirectory) - } else if !attr.perm().owner_executable() { - ax_err!(PermissionDenied) - } else { - *CURRENT_DIR.lock() = CurrentDir { - path: abs_path, - node, - }; - Ok(()) - } -} - -pub(crate) fn rename(old: &str, new: &str) -> AxResult { - if parent_node_of(None, new).lookup(new).is_ok() { - warn!("dst file already exist, now remove it"); - remove_file(None, new)?; - } - parent_node_of(None, old).rename(old, new) -} diff --git a/modules/axfs/tests/test_common/mod.rs b/modules/axfs/tests/test_common/mod.rs deleted file mode 100644 index 8cbf40e036..0000000000 --- a/modules/axfs/tests/test_common/mod.rs +++ /dev/null @@ -1,262 +0,0 @@ -use axfs::api as fs; -use axio as io; - -use fs::{File, FileType, OpenOptions}; -use io::{Error, Result, prelude::*}; - -macro_rules! assert_err { - ($expr: expr) => { - assert!(($expr).is_err()) - }; - ($expr: expr, $err: ident) => { - assert_eq!(($expr).err(), Some(Error::$err)) - }; -} - -fn test_read_write_file() -> Result<()> { - let fname = "///very/long//.././long//./path/./test.txt"; - println!("read and write file {fname:?}:"); - - // read and write - let mut file = File::options().read(true).write(true).open(fname)?; - let file_size = file.metadata()?.len(); - let mut contents = String::new(); - file.read_to_string(&mut contents)?; - print!("{contents}"); - assert_eq!(contents.len(), file_size as usize); - assert_eq!(file.write(b"Hello, world!\n")?, 14); // append - drop(file); - - // read again and check - let new_contents = fs::read_to_string(fname)?; - print!("{new_contents}"); - assert_eq!(new_contents, contents + "Hello, world!\n"); - - // append and check - let mut file = OpenOptions::new().append(true).open(fname)?; - assert_eq!(file.write(b"new line\n")?, 9); - drop(file); - - let new_contents2 = fs::read_to_string(fname)?; - print!("{new_contents2}"); - assert_eq!(new_contents2, new_contents + "new line\n"); - - // open a non-exist file - assert_err!(File::open("/not/exist/file"), NotFound); - - println!("test_read_write_file() OK!"); - Ok(()) -} - -fn test_read_dir() -> Result<()> { - let dir = "/././//./"; - println!("list directory {dir:?}:"); - for entry in fs::read_dir(dir)? { - let entry = entry?; - println!(" {}", entry.file_name()); - } - println!("test_read_dir() OK!"); - Ok(()) -} - -fn test_file_permission() -> Result<()> { - let fname = "./short.txt"; - println!("test permission {fname:?}:"); - - // write a file that open with read-only mode - let mut buf = [0; 256]; - let mut file = File::open(fname)?; - let n = file.read(&mut buf)?; - assert_err!(file.write(&buf), PermissionDenied); - drop(file); - - // read a file that open with write-only mode - let mut file = File::create(fname)?; - assert_err!(file.read(&mut buf), PermissionDenied); - assert!(file.write(&buf[..n]).is_ok()); - drop(file); - - // open with empty options - assert_err!(OpenOptions::new().open(fname), InvalidInput); - - // read as a directory - assert_err!(fs::read_dir(fname), NotADirectory); - assert_err!(fs::read("short.txt/"), NotADirectory); - assert_err!(fs::metadata("/short.txt/"), NotADirectory); - - // create as a directory - assert_err!(fs::write("error/", "should not create"), NotADirectory); - assert_err!(fs::metadata("error/"), NotFound); - assert_err!(fs::metadata("error"), NotFound); - - // read/write a directory - assert_err!(fs::read_to_string("/dev"), IsADirectory); - assert_err!(fs::write(".", "test"), IsADirectory); - - println!("test_file_permisson() OK!"); - Ok(()) -} - -fn test_create_file_dir() -> Result<()> { - // create a file and test existence - let fname = "././/very-long-dir-name/..///new-file.txt"; - println!("test create file {fname:?}:"); - assert_err!(fs::metadata(fname), NotFound); - let contents = "create a new file!\n"; - fs::write(fname, contents)?; - - let dirents = fs::read_dir(".")? - .map(|e| e.unwrap().file_name()) - .collect::>(); - println!("dirents = {dirents:?}"); - assert!(dirents.contains(&"new-file.txt".into())); - assert_eq!(fs::read_to_string(fname)?, contents); - assert_err!(File::create_new(fname), AlreadyExists); - - // create a directory and test existence - let dirname = "///././/very//.//long/./new-dir"; - println!("test create dir {dirname:?}:"); - assert_err!(fs::metadata(dirname), NotFound); - fs::create_dir(dirname)?; - - let dirents = fs::read_dir("./very/long")? - .map(|e| e.unwrap().file_name()) - .collect::>(); - println!("dirents = {dirents:?}"); - assert!(dirents.contains(&"new-dir".into())); - assert!(fs::metadata(dirname)?.is_dir()); - assert_err!(fs::create_dir(dirname), AlreadyExists); - - println!("test_create_file_dir() OK!"); - Ok(()) -} - -fn test_remove_file_dir() -> Result<()> { - // remove a file and test existence - let fname = "//very-long-dir-name/..///new-file.txt"; - println!("test remove file {fname:?}:"); - assert_err!(fs::remove_dir(fname), NotADirectory); - assert!(fs::remove_file(fname).is_ok()); - assert_err!(fs::metadata(fname), NotFound); - assert_err!(fs::remove_file(fname), NotFound); - - // remove a directory and test existence - let dirname = "very//.//long/../long/.//./new-dir////"; - println!("test remove dir {dirname:?}:"); - assert_err!(fs::remove_file(dirname), IsADirectory); - assert!(fs::remove_dir(dirname).is_ok()); - assert_err!(fs::metadata(dirname), NotFound); - assert_err!(fs::remove_dir(fname), NotFound); - - // error cases - assert_err!(fs::remove_file(""), NotFound); - assert_err!(fs::remove_dir("/"), DirectoryNotEmpty); - assert_err!(fs::remove_dir("."), InvalidInput); - assert_err!(fs::remove_dir("../"), InvalidInput); - assert_err!(fs::remove_dir("./././/"), InvalidInput); - assert_err!(fs::remove_file("///very/./"), IsADirectory); - assert_err!(fs::remove_file("short.txt/"), NotADirectory); - assert_err!(fs::remove_dir(".///"), InvalidInput); - assert_err!(fs::remove_dir("/./very///"), DirectoryNotEmpty); - assert_err!(fs::remove_dir("very/long/.."), InvalidInput); - - println!("test_remove_file_dir() OK!"); - Ok(()) -} - -fn test_devfs_ramfs() -> Result<()> { - const N: usize = 32; - let mut buf = [1; N]; - - // list '/' and check if /dev and /tmp exist - let dirents = fs::read_dir("././//.//")? - .map(|e| e.unwrap().file_name()) - .collect::>(); - assert!(dirents.contains(&"dev".into())); - assert!(dirents.contains(&"tmp".into())); - - // read and write /dev/null - let mut file = File::options().read(true).write(true).open("/dev/./null")?; - assert_eq!(file.read_to_end(&mut Vec::new())?, 0); - assert_eq!(file.write(&buf)?, N); - assert_eq!(buf, [1; N]); - - // read and write /dev/zero - let mut file = OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(true) - .open("////dev/zero")?; - assert_eq!(file.read(&mut buf)?, N); - assert!(file.write_all(&buf).is_ok()); - assert_eq!(buf, [0; N]); - - // list /dev - let dirents = fs::read_dir("/dev")? - .map(|e| e.unwrap().file_name()) - .collect::>(); - assert!(dirents.contains(&"null".into())); - assert!(dirents.contains(&"zero".into())); - - // stat /dev - let dname = "/dev"; - let dir = File::open(dname)?; - let md = dir.metadata()?; - println!("metadata of {dname:?}: {md:?}"); - assert_eq!(md.file_type(), FileType::Dir); - assert!(!md.is_file()); - assert!(md.is_dir()); - - // stat /dev/foo/bar - let fname = ".//.///././/./dev///.///./foo//././bar"; - let file = File::open(fname)?; - let md = file.metadata()?; - println!("metadata of {fname:?}: {md:?}"); - assert_eq!(md.file_type(), FileType::CharDevice); - assert!(!md.is_dir()); - - // error cases - assert_err!(fs::metadata("/dev/null/"), NotADirectory); - assert_err!(fs::create_dir("dev"), AlreadyExists); - assert_err!(File::create_new("/dev/"), AlreadyExists); - assert_err!(fs::create_dir("/dev/zero"), AlreadyExists); - assert_err!(fs::write("/dev/stdout", "test"), PermissionDenied); - assert_err!(fs::create_dir("/dev/test"), PermissionDenied); - assert_err!(fs::remove_file("/dev/null"), PermissionDenied); - assert_err!(fs::remove_dir("./dev"), PermissionDenied); - assert_err!(fs::remove_dir("./dev/."), InvalidInput); - assert_err!(fs::remove_dir("///dev//..//"), InvalidInput); - - // parent of '/dev' - assert_eq!(fs::create_dir("///dev//..//233//"), Ok(())); - assert_eq!(fs::write(".///dev//..//233//.///test.txt", "test"), Ok(())); - assert_err!(fs::remove_file("./dev//../..//233//.///test.txt"), NotFound); - assert_eq!(fs::remove_file("./dev//..//233//../233/./test.txt"), Ok(())); - assert_eq!(fs::remove_dir("dev//foo/../foo/../.././/233"), Ok(())); - assert_err!(fs::remove_dir("very/../dev//"), PermissionDenied); - - // tests in /tmp - assert_eq!(fs::metadata("tmp")?.file_type(), FileType::Dir); - assert_eq!(fs::create_dir(".///tmp///././dir"), Ok(())); - assert_eq!(fs::read_dir("tmp").unwrap().count(), 1); - assert_eq!(fs::write(".///tmp///dir//.///test.txt", "test"), Ok(())); - assert_eq!(fs::read("tmp//././/dir//.///test.txt"), Ok("test".into())); - // assert_err!(fs::remove_dir("dev/../tmp//dir"), DirectoryNotEmpty); // TODO - assert_err!(fs::remove_dir("/tmp/dir/../dir"), DirectoryNotEmpty); - assert_eq!(fs::remove_file("./tmp//dir//test.txt"), Ok(())); - assert_eq!(fs::remove_dir("tmp/dir/.././dir///"), Ok(())); - assert_eq!(fs::read_dir("tmp").unwrap().count(), 0); - - println!("test_devfs_ramfs() OK!"); - Ok(()) -} - -pub fn test_all() { - test_read_write_file().expect("test_read_write_file() failed"); - test_read_dir().expect("test_read_dir() failed"); - test_file_permission().expect("test_file_permission() failed"); - test_create_file_dir().expect("test_create_file_dir() failed"); - test_remove_file_dir().expect("test_remove_file_dir() failed"); - test_devfs_ramfs().expect("test_devfs_ramfs() failed"); -} diff --git a/modules/axfs/tests/test_fatfs.rs b/modules/axfs/tests/test_fatfs.rs deleted file mode 100644 index 481c5c8cf5..0000000000 --- a/modules/axfs/tests/test_fatfs.rs +++ /dev/null @@ -1,27 +0,0 @@ -#![cfg(not(feature = "myfs"))] - -mod test_common; - -use axdriver::AxDeviceContainer; -use axdriver_block::ramdisk::RamDisk; - -const IMG_PATH: &str = "resources/fat16.img"; - -fn make_disk() -> std::io::Result { - let path = std::env::current_dir()?.join(IMG_PATH); - println!("Loading disk image from {path:?} ..."); - let data = std::fs::read(path)?; - println!("size = {} bytes", data.len()); - Ok(RamDisk::from(data.as_slice())) -} - -#[test] -fn test_fatfs() { - println!("Testing fatfs with ramdisk ..."); - - let disk = make_disk().expect("failed to load disk image"); - axtask::init_scheduler(); // call this to use `axsync::Mutex`. - axfs::init_filesystems(AxDeviceContainer::from_one(disk)); - - test_common::test_all(); -} diff --git a/modules/axfs/tests/test_fs.rs b/modules/axfs/tests/test_fs.rs new file mode 100644 index 0000000000..5a0a5de3a9 --- /dev/null +++ b/modules/axfs/tests/test_fs.rs @@ -0,0 +1,175 @@ +// FIXME: The test is broken after we changed the ramdisk implementation! + +#![allow(unused)] + +use std::collections::HashSet; + +use axdriver_block::ramdisk::RamDisk; +use axfs::{File, FsContext, fs}; +use axfs_ng_vfs::{ + Filesystem, Location, Mountpoint, NodePermission, NodeType, VfsError, VfsResult, path::Path, +}; +use axio::Read; + +type RawMutex = spin::Mutex<()>; + +fn list_files(cx: &FsContext, path: impl AsRef) -> VfsResult> { + cx.read_dir(path)? + .map(|it| it.map(|entry| entry.name.to_owned())) + .collect() +} + +fn test_fs_read(fs: &Filesystem) -> VfsResult<()> { + let mount = Mountpoint::new_root(fs); + let cx: FsContext> = FsContext::new(mount.root_location()); + + let names = list_files(&cx, "/").unwrap(); + assert!( + ["short.txt", "long.txt", "a", "very-long-dir-name"] + .into_iter() + .all(|it| names.contains(it)) + ); + assert_eq!(cx.metadata("short.txt")?.size, 14); + assert_eq!(cx.metadata("long.txt")?.size, 14000); + + let entries = cx.read_dir("/")?.collect::>>()?; + for entry in entries { + assert!(cx.root_dir().lookup_no_follow(&entry.name)?.inode() == entry.ino); + } + + assert_eq!( + list_files(&cx, "/a/long/path")?, + ["test.txt", ".", ".."] + .into_iter() + .map(str::to_owned) + .collect() + ); + assert_eq!( + cx.read_to_string("/a/long/path/test.txt")?, + "Rust is cool!\n" + ); + + assert_eq!( + cx.resolve("/a/long/path/test.txt")? + .absolute_path()? + .to_string(), + "/a/long/path/test.txt" + ); + + assert!( + cx.resolve("/very-long-dir-name/very-long-file-name.txt")? + .is_file() + ); + let mut file = File::open(&cx, "/very-long-dir-name/very-long-file-name.txt")?; + let mut buf = vec![]; + file.read_to_end(&mut buf)?; + drop(file); + assert_eq!(core::str::from_utf8(&buf).unwrap(), "Rust is cool!\n"); + + Ok(()) +} + +fn test_fs_write(fs: &Filesystem) -> VfsResult<()> { + let mount = Mountpoint::new_root(fs); + let cx = FsContext::new(mount.root_location()); + + let mode = NodePermission::from_bits(0o766).unwrap(); + cx.create_dir("temp", mode)?; + cx.create_dir("temp2", mode)?; + assert!(cx.resolve("temp").is_ok() && cx.resolve("temp2").is_ok()); + // cx.rename("temp", "temp2")?; + // assert!(cx.resolve("temp").is_err() && cx.resolve("temp2").is_ok()); + + cx.create_dir("temp", mode)?; + cx.resolve("temp")? + .create("test.txt", NodeType::RegularFile, NodePermission::default())?; + assert!(matches!( + cx.rename("temp2", "temp"), + Err(VfsError::ENOTEMPTY) + )); + + cx.write("/test.txt", "hello world".as_bytes())?; + assert_eq!(cx.read_to_string("/test.txt")?, "hello world"); + + cx.create_dir("test_dir", NodePermission::from_bits_truncate(0o755))?; + cx.rename("test_dir", "test")?; + cx.remove_dir("test")?; + + println!("---------------------"); + + if cx.link("/test.txt", "/test_link").is_ok() { + assert_eq!(cx.read_to_string("/test_link")?, "hello world"); + } + if cx.symlink("/test.txt", "/test_symlink").is_ok() { + assert_eq!(cx.read_to_string("/test_symlink")?, "hello world"); + } + + // FAT has errornous rename implementation + if fs.name() != "vfat" { + cx.write("rename1", "hello world".as_bytes())?; + cx.write("rename2", "hello world2".as_bytes())?; + cx.rename("rename1", "rename2")?; + assert_eq!(cx.read_to_string("rename2")?, "hello world"); + } + + Ok(()) +} + +fn test_fs_full(fs: Filesystem) -> VfsResult<()> { + let mut thrds = vec![]; + for _ in 0..1 { + let fs = fs.clone(); + thrds.push(std::thread::spawn(move || test_fs_read(&fs))); + } + for th in thrds { + th.join().unwrap()?; + } + test_fs_write(&fs)?; + Ok(()) +} + +#[test] +#[cfg(feature = "fat")] +fn test_fatfs() { + for path in ["resources/fat16.img", "resources/fat32.img"] { + let data = std::fs::read(path).unwrap(); + let disk = RamDisk::from(&data); + let fs = fs::fat::FatFilesystem::::new(disk); + test_fs_full(fs).unwrap(); + } +} + +#[test] +#[cfg(feature = "ext4")] +fn test_ext4() { + let data = std::fs::read("resources/ext4.img").unwrap(); + let disk = RamDisk::from(&data); + let fs = fs::ext4::Ext4Filesystem::::new(disk).unwrap(); + test_fs_full(fs).unwrap(); +} + +#[test] +#[cfg(all(feature = "ext4", feature = "fat"))] +fn test_mount() { + env_logger::init(); + let disk = RamDisk::from(&std::fs::read("resources/ext4.img").unwrap()); + let fs = fs::ext4::Ext4Filesystem::::new(disk).unwrap(); + + let disk = RamDisk::from(&std::fs::read("resources/fat16.img").unwrap()); + let sub_fs = fs::fat::FatFilesystem::::new(disk); + + let mount = Mountpoint::new(&fs, None); + let cx = FsContext::new(mount.root_location()); + cx.resolve("a").unwrap().mount(&sub_fs); + + let mt = cx.resolve("a").unwrap(); + assert!(!mt.is_mountpoint() && mt.is_root_of_mount()); + assert_eq!(mt.filesystem().name(), "vfat"); + assert_eq!(mt.absolute_path().unwrap().to_string(), "/a"); + + assert_eq!( + cx.read_to_string("/a/../a/very-long-dir-name/very-long-file-name.txt") + .unwrap(), + "Rust is cool!\n" + ); +} diff --git a/modules/axfs/tests/test_ramfs.rs b/modules/axfs/tests/test_ramfs.rs deleted file mode 100644 index 914c245bc6..0000000000 --- a/modules/axfs/tests/test_ramfs.rs +++ /dev/null @@ -1,56 +0,0 @@ -#![cfg(feature = "myfs")] - -mod test_common; - -use std::sync::Arc; - -use axdriver::AxDeviceContainer; -use axdriver_block::ramdisk::RamDisk; -use axfs::api::{self as fs, File}; -use axfs::fops::{Disk, MyFileSystemIf}; -use axfs_ramfs::RamFileSystem; -use axfs_vfs::VfsOps; -use axio::{Result, Write}; - -struct MyFileSystemIfImpl; - -#[crate_interface::impl_interface] -impl MyFileSystemIf for MyFileSystemIfImpl { - fn new_myfs(_disk: Disk) -> Arc { - Arc::new(RamFileSystem::new()) - } -} - -fn create_init_files() -> Result<()> { - fs::write("./short.txt", "Rust is cool!\n")?; - let mut file = File::create_new("/long.txt")?; - for _ in 0..100 { - file.write_fmt(format_args!("Rust is cool!\n"))?; - } - - fs::create_dir("very-long-dir-name")?; - fs::write( - "very-long-dir-name/very-long-file-name.txt", - "Rust is cool!\n", - )?; - - fs::create_dir("very")?; - fs::create_dir("//very/long")?; - fs::create_dir("/./very/long/path")?; - fs::write(".//very/long/path/test.txt", "Rust is cool!\n")?; - Ok(()) -} - -#[test] -fn test_ramfs() { - println!("Testing ramfs ..."); - - axtask::init_scheduler(); // call this to use `axsync::Mutex`. - axfs::init_filesystems(AxDeviceContainer::from_one(RamDisk::default())); // dummy disk, actually not used. - - if let Err(e) = create_init_files() { - log::warn!("failed to create init files: {:?}", e); - } - - test_common::test_all(); -} diff --git a/modules/axhal/src/paging.rs b/modules/axhal/src/paging.rs index 3b51e6d9ca..075c8b2520 100644 --- a/modules/axhal/src/paging.rs +++ b/modules/axhal/src/paging.rs @@ -3,14 +3,13 @@ use axalloc::{UsageKind, global_allocator}; use memory_addr::{PAGE_SIZE_4K, PhysAddr, VirtAddr}; use page_table_multiarch::PagingHandler; - -use crate::mem::{phys_to_virt, virt_to_phys}; - #[doc(no_inline)] pub use page_table_multiarch::{MappingFlags, PageSize, PagingError, PagingResult}; -/// Implementation of [`PagingHandler`], to provide physical memory manipulation to -/// the [page_table_multiarch] crate. +use crate::mem::{phys_to_virt, virt_to_phys}; + +/// Implementation of [`PagingHandler`], to provide physical memory manipulation +/// to the [page_table_multiarch] crate. pub struct PagingHandlerImpl; impl PagingHandler for PagingHandlerImpl { @@ -35,14 +34,18 @@ cfg_if::cfg_if! { if #[cfg(target_arch = "x86_64")] { /// The architecture-specific page table. pub type PageTable = page_table_multiarch::x86_64::X64PageTable; + pub type PageTableMut<'a> = page_table_multiarch::x86_64::X64PageTableMut<'a, PagingHandlerImpl>; } else if #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] { /// The architecture-specific page table. pub type PageTable = page_table_multiarch::riscv::Sv39PageTable; + pub type PageTableMut<'a> = page_table_multiarch::riscv::Sv39PageTableMut<'a, PagingHandlerImpl>; } else if #[cfg(target_arch = "aarch64")]{ /// The architecture-specific page table. pub type PageTable = page_table_multiarch::aarch64::A64PageTable; + pub type PageTableMut<'a> = page_table_multiarch::aarch64::A64PageTableMut<'a, PagingHandlerImpl>; } else if #[cfg(target_arch = "loongarch64")] { /// The architecture-specific page table. pub type PageTable = page_table_multiarch::loongarch64::LA64PageTable; + pub type PageTableMut<'a> = page_table_multiarch::loongarch64::LA64PageTableMut<'a, PagingHandlerImpl>; } } diff --git a/modules/axmm/Cargo.toml b/modules/axmm/Cargo.toml index be58e27399..760d7f7d6c 100644 --- a/modules/axmm/Cargo.toml +++ b/modules/axmm/Cargo.toml @@ -10,16 +10,22 @@ repository = "https://github.com/arceos-org/arceos/tree/main/modules/axmm" documentation = "https://arceos-org.github.io/arceos/axmm/index.html" [features] +default = [] copy = ["page_table_multiarch/copy-from"] [dependencies] -axalloc.workspace = true -axconfig.workspace = true -axerrno.workspace = true +axalloc = { workspace = true } +axconfig = { workspace = true } +axerrno = { workspace = true } +axfs = { workspace = true } +axfs-ng-vfs = { workspace = true } axhal = { workspace = true, features = ["paging"] } -kspin.workspace = true -lazyinit.workspace = true -log.workspace = true -memory_addr.workspace = true +axsync = { workspace = true } +axtask = { workspace = true } +enum_dispatch = { workspace = true } +kspin = { workspace = true } +lazyinit = { workspace = true } +log = { workspace = true } +memory_addr = { workspace = true } memory_set = { version = "0.4", features = ["axerrno"] } page_table_multiarch = { workspace = true, optional = true } diff --git a/modules/axmm/src/aspace.rs b/modules/axmm/src/aspace.rs index 61ff253afa..2db594db95 100644 --- a/modules/axmm/src/aspace.rs +++ b/modules/axmm/src/aspace.rs @@ -1,15 +1,19 @@ -use core::fmt; - -use axerrno::{AxError, AxResult, ax_err}; -use axhal::mem::phys_to_virt; -use axhal::paging::{MappingFlags, PageTable}; -use axhal::trap::PageFaultFlags; +use alloc::sync::Arc; +use core::{fmt, ops::DerefMut}; + +use axerrno::{AxError, AxResult, ax_bail}; +use axhal::{ + mem::phys_to_virt, + paging::{MappingFlags, PageTable}, + trap::PageFaultFlags, +}; +use axsync::Mutex; use memory_addr::{ MemoryAddr, PAGE_SIZE_4K, PageIter4K, PhysAddr, VirtAddr, VirtAddrRange, is_aligned_4k, }; use memory_set::{MemoryArea, MemorySet}; -use crate::backend::Backend; +use crate::backend::{Backend, BackendOps}; /// The virtual memory address space. pub struct AddrSpace { @@ -39,6 +43,11 @@ impl AddrSpace { &self.pt } + /// Returns a mutable reference to the inner page table. + pub const fn page_table_mut(&mut self) -> &mut PageTable { + &mut self.pt + } + /// Returns the root physical address of the inner page table. pub const fn page_table_root(&self) -> PhysAddr { self.pt.root_paddr() @@ -46,12 +55,11 @@ impl AddrSpace { /// Checks if the address space contains the given address range. pub fn contains_range(&self, start: VirtAddr, size: usize) -> bool { - self.va_range - .contains_range(VirtAddrRange::from_start_size(start, size)) + self.va_range.contains(start) && (self.va_range.end - start) >= size } /// Creates a new empty address space. - pub(crate) fn new_empty(base: VirtAddr, size: usize) -> AxResult { + pub fn new_empty(base: VirtAddr, size: usize) -> AxResult { Ok(Self { va_range: VirtAddrRange::from_start_size(base, size), areas: MemorySet::new(), @@ -68,25 +76,41 @@ impl AddrSpace { /// Returns an error if the two address spaces overlap. #[cfg(feature = "copy")] pub fn copy_mappings_from(&mut self, other: &AddrSpace) -> AxResult { - if self.va_range.overlaps(other.va_range) { - return ax_err!(InvalidInput, "address space overlap"); + self.pt + .modify() + .copy_from(&other.pt, other.base(), other.size()); + Ok(()) + } + + fn validate_region(&self, start: VirtAddr, size: usize) -> AxResult { + if !self.contains_range(start, size) { + ax_bail!(NoMemory, "address out of range"); + } + if !start.is_aligned_4k() || !is_aligned_4k(size) { + ax_bail!(InvalidInput, "address is not aligned"); } - self.pt.copy_from(&other.pt, other.base(), other.size()); Ok(()) } /// Finds a free area that can accommodate the given size. /// - /// The search starts from the given hint address, and the area should be within the given limit range. + /// The search starts from the given hint address, and the area should be + /// within the given limit range. /// - /// Returns the start address of the free area. Returns None if no such area is found. + /// Returns the start address of the free area. Returns None if no such area + /// is found. pub fn find_free_area( &self, hint: VirtAddr, size: usize, limit: VirtAddrRange, + align: usize, ) -> Option { - self.areas.find_free_area(hint, size, limit, PAGE_SIZE_4K) + self.areas.find_free_area(hint, size, limit, align) + } + + pub fn find_area(&self, vaddr: VirtAddr) -> Option<&MemoryArea> { + self.areas.find(vaddr) } /// Add a new linear mapping. @@ -104,43 +128,64 @@ impl AddrSpace { size: usize, flags: MappingFlags, ) -> AxResult { - if !self.contains_range(start_vaddr, size) { - return ax_err!(InvalidInput, "address out of range"); - } - if !start_vaddr.is_aligned_4k() || !start_paddr.is_aligned_4k() || !is_aligned_4k(size) { - return ax_err!(InvalidInput, "address not aligned"); + self.validate_region(start_vaddr, size)?; + + if !start_paddr.is_aligned_4k() { + ax_bail!(InvalidInput, "address is not aligned"); } - let offset = start_vaddr.as_usize() - start_paddr.as_usize(); + let offset = start_vaddr.as_usize() as isize - start_paddr.as_usize() as isize; let area = MemoryArea::new(start_vaddr, size, flags, Backend::new_linear(offset)); self.areas.map(area, &mut self.pt, false)?; Ok(()) } - /// Add a new allocation mapping. - /// - /// See [`Backend`] for more details about the mapping backends. - /// - /// The `flags` parameter indicates the mapping permissions and attributes. - /// - /// Returns an error if the address range is out of the address space or not - /// aligned. - pub fn map_alloc( + pub fn map( &mut self, start: VirtAddr, size: usize, flags: MappingFlags, populate: bool, + backend: Backend, ) -> AxResult { - if !self.contains_range(start, size) { - return ax_err!(InvalidInput, "address out of range"); + self.validate_region(start, size)?; + + let area = MemoryArea::new(start, size, flags, backend); + self.areas.map(area, &mut self.pt, false)?; + if populate { + self.populate_area(start, size, flags)?; } - if !start.is_aligned_4k() || !is_aligned_4k(size) { - return ax_err!(InvalidInput, "address not aligned"); + Ok(()) + } + + /// Populates the area with physical frames, returning false if the area + /// contains unmapped area. + pub fn populate_area( + &mut self, + mut start: VirtAddr, + size: usize, + access_flags: MappingFlags, + ) -> AxResult { + self.validate_region(start, size)?; + let end = start + size; + + let mut modify = self.pt.modify(); + while let Some(area) = self.areas.find(start) { + let range = VirtAddrRange::new(start, area.end().min(end)); + area.backend() + .populate(range, area.flags(), access_flags, &mut modify)?; + start = area.end(); + assert!(start.is_aligned_4k()); + if start >= end { + break; + } + } + + if start < end { + // If the area is not fully mapped, we return ENOMEM. + ax_bail!(NoMemory); } - let area = MemoryArea::new(start, size, flags, Backend::new_alloc(populate)); - self.areas.map(area, &mut self.pt, false)?; Ok(()) } @@ -149,12 +194,7 @@ impl AddrSpace { /// Returns an error if the address range is out of the address space or not /// aligned. pub fn unmap(&mut self, start: VirtAddr, size: usize) -> AxResult { - if !self.contains_range(start, size) { - return ax_err!(InvalidInput, "address out of range"); - } - if !start.is_aligned_4k() || !is_aligned_4k(size) { - return ax_err!(InvalidInput, "address not aligned"); - } + self.validate_region(start, size)?; self.areas.unmap(start, size, &mut self.pt)?; Ok(()) @@ -168,7 +208,7 @@ impl AddrSpace { F: FnMut(VirtAddr, usize, usize), { if !self.contains_range(start, size) { - return ax_err!(InvalidInput, "address out of range"); + ax_bail!(InvalidInput, "address out of range"); } let mut cnt = 0; // If start is aligned to 4K, start_align_down will be equal to start_align_up. @@ -176,7 +216,7 @@ impl AddrSpace { for vaddr in PageIter4K::new(start.align_down_4k(), end_align_up) .expect("Failed to create page iterator") { - let (mut paddr, _, _) = self.pt.query(vaddr).map_err(|_| AxError::BadAddress)?; + let (mut paddr, ..) = self.pt.query(vaddr).map_err(|_| AxError::BadAddress)?; let mut copy_size = (size - cnt).min(PAGE_SIZE_4K); @@ -223,18 +263,11 @@ impl AddrSpace { /// Returns an error if the address range is out of the address space or not /// aligned. pub fn protect(&mut self, start: VirtAddr, size: usize, flags: MappingFlags) -> AxResult { - if !self.contains_range(start, size) { - return ax_err!(InvalidInput, "address out of range"); - } - if !start.is_aligned_4k() || !is_aligned_4k(size) { - return ax_err!(InvalidInput, "address not aligned"); - } + self.validate_region(start, size)?; + + self.areas + .protect(start, size, |_| Some(flags), &mut self.pt)?; - // TODO - self.pt - .protect_region(start, size, flags, true) - .map_err(|_| AxError::BadState)? - .ignore(); Ok(()) } @@ -253,7 +286,9 @@ impl AddrSpace { size: usize, access_flags: MappingFlags, ) -> bool { - let mut range = VirtAddrRange::from_start_size(start, size); + let Some(mut range) = VirtAddrRange::try_from_start_size(start, size) else { + return false; + }; for area in self.areas.iter() { if area.end() <= range.start { continue; @@ -287,15 +322,66 @@ impl AddrSpace { return false; } if let Some(area) = self.areas.find(vaddr) { - let orig_flags = area.flags(); - if orig_flags.contains(access_flags) { - return area - .backend() - .handle_page_fault(vaddr, orig_flags, &mut self.pt); + let flags = area.flags(); + if flags.contains(access_flags) { + let page_size = area.backend().page_size(); + let populate_result = area.backend().populate( + VirtAddrRange::from_start_size(vaddr.align_down(page_size), page_size as _), + flags, + access_flags, + &mut self.pt.modify(), + ); + return match populate_result { + Ok((n, callback)) => { + if let Some(cb) = callback { + cb(self); + } + if n == 0 { + warn!("No pages populated for {vaddr:?} ({flags:?})"); + false + } else { + true + } + } + Err(err) => { + warn!("Failed to populate pages for {vaddr:?} ({flags:?}): {err}"); + false + } + }; } } false } + + /// Attempts to clone the current address space into a new one. + /// + /// This method creates a new empty address space with the same base and + /// size, then iterates over all memory areas in the original address + /// space to copy or share their mappings into the new one. + pub fn try_clone(&mut self) -> AxResult>> { + let new_aspace = Arc::new(Mutex::new(Self::new_empty(self.base(), self.size())?)); + let new_aspace_clone = new_aspace.clone(); + + let mut guard = new_aspace.lock(); + + let mut self_modify = self.pt.modify(); + for area in self.areas.iter() { + let new_backend = area.backend().clone_map( + area.va_range(), + area.flags(), + &mut self_modify, + &mut guard.pt.modify(), + &new_aspace_clone, + )?; + + let new_area = MemoryArea::new(area.start(), area.size(), area.flags(), new_backend); + let aspace = guard.deref_mut(); + aspace.areas.map(new_area, &mut aspace.pt, false)?; + } + drop(guard); + + Ok(new_aspace) + } } impl fmt::Debug for AddrSpace { diff --git a/modules/axmm/src/backend/alloc.rs b/modules/axmm/src/backend/alloc.rs deleted file mode 100644 index 19b43443c1..0000000000 --- a/modules/axmm/src/backend/alloc.rs +++ /dev/null @@ -1,112 +0,0 @@ -use axalloc::{UsageKind, global_allocator}; -use axhal::mem::{phys_to_virt, virt_to_phys}; -use axhal::paging::{MappingFlags, PageSize, PageTable}; -use memory_addr::{PAGE_SIZE_4K, PageIter4K, PhysAddr, VirtAddr}; - -use super::Backend; - -fn alloc_frame(zeroed: bool) -> Option { - let vaddr = VirtAddr::from( - global_allocator() - .alloc_pages(1, PAGE_SIZE_4K, UsageKind::VirtMem) - .ok()?, - ); - if zeroed { - unsafe { core::ptr::write_bytes(vaddr.as_mut_ptr(), 0, PAGE_SIZE_4K) }; - } - let paddr = virt_to_phys(vaddr); - Some(paddr) -} - -fn dealloc_frame(frame: PhysAddr) { - let vaddr = phys_to_virt(frame); - global_allocator().dealloc_pages(vaddr.as_usize(), 1, UsageKind::VirtMem); -} - -impl Backend { - /// Creates a new allocation mapping backend. - pub const fn new_alloc(populate: bool) -> Self { - Self::Alloc { populate } - } - - pub(crate) fn map_alloc( - &self, - start: VirtAddr, - size: usize, - flags: MappingFlags, - pt: &mut PageTable, - populate: bool, - ) -> bool { - debug!( - "map_alloc: [{:#x}, {:#x}) {:?} (populate={})", - start, - start + size, - flags, - populate - ); - if populate { - // allocate all possible physical frames for populated mapping. - for addr in PageIter4K::new(start, start + size).unwrap() { - if let Some(frame) = alloc_frame(true) { - if let Ok(tlb) = pt.map(addr, frame, PageSize::Size4K, flags) { - tlb.ignore(); // TLB flush on map is unnecessary, as there are no outdated mappings. - } else { - return false; - } - } - } - true - } else { - // Map to a empty entry for on-demand mapping. - let flags = MappingFlags::empty(); - pt.map_region(start, |_| 0.into(), size, flags, false, false) - .map(|tlb| tlb.ignore()) - .is_ok() - } - } - - pub(crate) fn unmap_alloc( - &self, - start: VirtAddr, - size: usize, - pt: &mut PageTable, - _populate: bool, - ) -> bool { - debug!("unmap_alloc: [{:#x}, {:#x})", start, start + size); - for addr in PageIter4K::new(start, start + size).unwrap() { - if let Ok((frame, page_size, tlb)) = pt.unmap(addr) { - // Deallocate the physical frame if there is a mapping in the - // page table. - if page_size.is_huge() { - return false; - } - tlb.flush(); - dealloc_frame(frame); - } else { - // Deallocation is needn't if the page is not mapped. - } - } - true - } - - pub(crate) fn handle_page_fault_alloc( - &self, - vaddr: VirtAddr, - orig_flags: MappingFlags, - pt: &mut PageTable, - populate: bool, - ) -> bool { - if populate { - false // Populated mappings should not trigger page faults. - } else if let Some(frame) = alloc_frame(true) { - // Allocate a physical frame lazily and map it to the fault address. - // `vaddr` does not need to be aligned. It will be automatically - // aligned during `pt.remap` regardless of the page size. - pt.remap(vaddr, frame, orig_flags) - .map(|(_, tlb)| tlb.flush()) - .is_ok() - } else { - false - } - } -} diff --git a/modules/axmm/src/backend/cow.rs b/modules/axmm/src/backend/cow.rs new file mode 100644 index 0000000000..a4f9862de1 --- /dev/null +++ b/modules/axmm/src/backend/cow.rs @@ -0,0 +1,231 @@ +use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc}; +use core::slice; + +use axerrno::{AxError, AxResult}; +use axfs::FileBackend; +use axhal::{ + mem::phys_to_virt, + paging::{MappingFlags, PageSize, PageTableMut, PagingError}, +}; +use axsync::Mutex; +use kspin::SpinNoIrq; +use memory_addr::{PhysAddr, VirtAddr, VirtAddrRange}; + +use crate::{ + AddrSpace, + backend::{Backend, BackendOps, alloc_frame, dealloc_frame, pages_in}, +}; + +static FRAME_TABLE: SpinNoIrq> = SpinNoIrq::new(BTreeMap::new()); + +fn inc_frame_ref(paddr: PhysAddr) { + let mut table = FRAME_TABLE.lock(); + *table.entry(paddr).or_insert(0) += 1; +} + +fn dec_frame_ref(paddr: PhysAddr) -> usize { + let mut table = FRAME_TABLE.lock(); + if let Some(count) = table.get_mut(&paddr) { + let prev = *count; + if prev == 1 { + table.remove(&paddr); + } else { + *count -= 1; + } + prev as usize + } else { + 0 + } +} + +/// Copy-on-write mapping backend. +/// +/// This corresponds to the `MAP_PRIVATE` flag. +#[derive(Clone)] +pub struct CowBackend { + start: VirtAddr, + size: PageSize, + file: Option<(FileBackend, u64, Option)>, +} + +impl CowBackend { + fn alloc_new_at( + &self, + vaddr: VirtAddr, + flags: MappingFlags, + pt: &mut PageTableMut, + ) -> AxResult { + let frame = alloc_frame(true, self.size)?; + inc_frame_ref(frame); + + if let Some((file, file_start, file_end)) = &self.file { + let buf = unsafe { + slice::from_raw_parts_mut(phys_to_virt(frame).as_mut_ptr(), self.size as _) + }; + // vaddr can be smaller than self.start (at most 1 page) due to + // non-aligned mappings, we need to keep the gap clean. + let start = self.start.as_usize().saturating_sub(vaddr.as_usize()); + assert!(start < self.size as _); + + let file_start = + *file_start + vaddr.as_usize().saturating_sub(self.start.as_usize()) as u64; + let max_read = file_end + .map_or(u64::MAX, |end| end.saturating_sub(file_start)) + .min((buf.len() - start) as u64) as usize; + + file.read_at(&mut &mut buf[start..start + max_read], file_start)?; + } + pt.map(vaddr, frame, self.size, flags)?; + Ok(()) + } + + fn handle_cow_fault( + &self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: MappingFlags, + pt: &mut PageTableMut, + ) -> AxResult { + match dec_frame_ref(paddr) { + 0 => unreachable!(), + // There is only one AddrSpace reference to the page, + // so there is no need to copy it. + 1 => { + inc_frame_ref(paddr); + pt.protect(vaddr, flags)?; + } + // Allocates the new page and copies the contents of the original page, + // remapping the virtual address to the physical address of the new page. + 2.. => { + let new_frame = alloc_frame(false, self.size)?; + inc_frame_ref(new_frame); + unsafe { + core::ptr::copy_nonoverlapping( + phys_to_virt(paddr).as_ptr(), + phys_to_virt(new_frame).as_mut_ptr(), + self.size as _, + ); + } + + pt.remap(vaddr, new_frame, flags)?; + } + } + + Ok(()) + } +} + +impl BackendOps for CowBackend { + fn page_size(&self) -> PageSize { + self.size + } + + fn map(&self, range: VirtAddrRange, flags: MappingFlags, _pt: &mut PageTableMut) -> AxResult { + debug!("Cow::map: {range:?} {flags:?}",); + Ok(()) + } + + fn unmap(&self, range: VirtAddrRange, pt: &mut PageTableMut) -> AxResult { + debug!("Cow::unmap: {range:?}"); + for addr in pages_in(range, self.size)? { + if let Ok((frame, _flags, page_size)) = pt.unmap(addr) { + assert_eq!(page_size, self.size); + if dec_frame_ref(frame) == 1 { + dealloc_frame(frame, self.size); + } + } else { + // Deallocation is needn't if the page is not allocated. + } + } + Ok(()) + } + + fn populate( + &self, + range: VirtAddrRange, + flags: MappingFlags, + access_flags: MappingFlags, + pt: &mut PageTableMut, + ) -> AxResult<(usize, Option>)> { + let mut pages = 0; + for addr in pages_in(range, self.size)? { + match pt.query(addr) { + Ok((paddr, page_flags, page_size)) => { + assert_eq!(self.size, page_size); + if access_flags.contains(MappingFlags::WRITE) + && !page_flags.contains(MappingFlags::WRITE) + { + self.handle_cow_fault(addr, paddr, flags, pt)?; + pages += 1; + } else if page_flags.contains(access_flags) { + pages += 1; + } + } + // If the page is not mapped, try map it. + Err(PagingError::NotMapped) => { + self.alloc_new_at(addr, flags, pt)?; + pages += 1; + } + Err(_) => return Err(AxError::BadAddress), + } + } + Ok((pages, None)) + } + + fn clone_map( + &self, + range: VirtAddrRange, + flags: MappingFlags, + old_pt: &mut PageTableMut, + new_pt: &mut PageTableMut, + _new_aspace: &Arc>, + ) -> AxResult { + let cow_flags = flags - MappingFlags::WRITE; + + for vaddr in pages_in(range, self.size)? { + // Copy data from old memory area to new memory area. + match old_pt.query(vaddr) { + Ok((paddr, _, page_size)) => { + assert_eq!(page_size, self.size); + // If the page is mapped in the old page table: + // - Update its permissions in the old page table using `flags`. + // - Map the same physical page into the new page table at the same + // virtual address, with the same page size and `flags`. + inc_frame_ref(paddr); + + old_pt.protect(vaddr, cow_flags)?; + new_pt.map(vaddr, paddr, self.size, cow_flags)?; + } + // If the page is not mapped, skip it. + Err(PagingError::NotMapped) => {} + Err(_) => return Err(AxError::BadAddress), + }; + } + + Ok(Backend::Cow(self.clone())) + } +} + +impl Backend { + pub fn new_cow( + start: VirtAddr, + size: PageSize, + file: FileBackend, + file_start: u64, + file_end: Option, + ) -> Self { + Self::Cow(CowBackend { + start, + size, + file: Some((file, file_start, file_end)), + }) + } + + pub fn new_alloc(start: VirtAddr, size: PageSize) -> Self { + Self::Cow(CowBackend { + start, + size, + file: None, + }) + } +} diff --git a/modules/axmm/src/backend/file.rs b/modules/axmm/src/backend/file.rs new file mode 100644 index 0000000000..7344c32b58 --- /dev/null +++ b/modules/axmm/src/backend/file.rs @@ -0,0 +1,247 @@ +use alloc::{ + boxed::Box, + sync::{Arc, Weak}, + vec::Vec, +}; +use core::sync::atomic::{AtomicUsize, Ordering}; + +use axerrno::{AxError, AxResult}; +use axfs::{CachedFile, FileFlags}; +use axhal::paging::{MappingFlags, PageSize, PageTableMut, PagingError}; +use axsync::Mutex; +use memory_addr::{PAGE_SIZE_4K, VirtAddr, VirtAddrRange}; + +use crate::{ + AddrSpace, + backend::{Backend, BackendOps, pages_in}, +}; + +#[doc(hidden)] +pub struct FileBackendInner { + start: VirtAddr, + cache: CachedFile, + flags: FileFlags, + offset_page: u32, + handle: AtomicUsize, + futex_handle: Arc<()>, +} +impl Drop for FileBackendInner { + fn drop(&mut self) { + let handle = self.handle.load(Ordering::Acquire); + if handle != 0 { + unsafe { + self.cache.remove_evict_listener(handle); + } + } + } +} +impl FileBackendInner { + pub fn register_listener(self: &Arc, aspace: &Arc>) -> usize { + let aspace = Arc::downgrade(aspace); + self.cache.add_evict_listener({ + let this = Arc::downgrade(self); + move |pn, _page| { + let Some(this) = this.upgrade() else { + return; + }; + let Some(aspace) = aspace.upgrade() else { + // The address space has been dropped, nothing to do. + return; + }; + let Some(mut aspace) = aspace.try_lock() else { + // This can happen during the populate process, when new pages + // are being populated and old pages are being evicted. In this + // case, we delegate the unmapping to the populate process. + return; + }; + this.on_evict(pn, &mut aspace); + } + }) + } + + fn on_evict(self: &Arc, pn: u32, aspace: &mut AddrSpace) { + let Some(pn) = pn.checked_sub(self.offset_page) else { + return; + }; + let vaddr = self.start + pn as usize * PageSize::Size4K as usize; + if !aspace.find_area(vaddr).is_some_and( + |it| matches!(it.backend(), Backend::File(file) if Arc::ptr_eq(&file.0, self)), + ) { + // Ignore if the page is not controlled by this file mapping. + return; + } + + let pt = aspace.page_table_mut(); + match pt.modify().unmap(vaddr) { + Ok(_) | Err(PagingError::NotMapped) => {} + Err(err) => { + warn!("Failed to unmap page {:?}: {:?}", vaddr, err); + } + } + } +} + +/// File-backed mapping backend. +#[derive(Clone)] +pub struct FileBackend(Arc); +impl FileBackend { + fn check_flags(&self, flags: MappingFlags) -> AxResult { + let mut required_flags = FileFlags::empty(); + if flags.contains(MappingFlags::READ) { + required_flags |= FileFlags::READ; + } + if flags.contains(MappingFlags::WRITE) { + required_flags |= FileFlags::WRITE; + } + + if !self.0.flags.contains(required_flags) { + return Err(AxError::PermissionDenied); + } + Ok(()) + } + + pub fn futex_handle(&self) -> Weak<()> { + Arc::downgrade(&self.0.futex_handle) + } +} + +impl BackendOps for FileBackend { + fn page_size(&self) -> PageSize { + PageSize::Size4K + } + + fn map(&self, _range: VirtAddrRange, flags: MappingFlags, _pt: &mut PageTableMut) -> AxResult { + self.check_flags(flags) + } + + fn unmap(&self, range: VirtAddrRange, pt: &mut PageTableMut) -> AxResult { + for addr in pages_in(range, PageSize::Size4K)? { + match pt.unmap(addr) { + Ok(_) | Err(PagingError::NotMapped) => {} + Err(err) => { + warn!("Failed to unmap page {:?}: {:?}", addr, err); + return Err(err.into()); + } + } + } + Ok(()) + } + + fn on_protect( + &self, + _range: VirtAddrRange, + new_flags: MappingFlags, + _pt: &mut PageTableMut, + ) -> AxResult { + self.check_flags(new_flags) + } + + fn populate( + &self, + range: VirtAddrRange, + flags: MappingFlags, + access_flags: MappingFlags, + pt: &mut PageTableMut, + ) -> AxResult<(usize, Option>)> { + let mut pages = 0; + let mut to_be_evicted = Vec::new(); + let start_page = ((range.start - self.0.start) / PAGE_SIZE_4K) as u32 + self.0.offset_page; + for (i, addr) in pages_in(range, PageSize::Size4K)?.enumerate() { + let pn = start_page + i as u32; + match pt.query(addr) { + Ok((paddr, page_flags, _)) => { + if access_flags.contains(MappingFlags::WRITE) + && !page_flags.contains(MappingFlags::WRITE) + { + let in_memory = self.0.cache.in_memory(); + self.0.cache.with_page(pn, |page| { + if !in_memory { + page.expect("page should be present").mark_dirty(); + } + pt.remap(addr, paddr, flags)?; + pages += 1; + AxResult::Ok(()) + })?; + } else if page_flags.contains(access_flags) { + pages += 1; + } + } + // If the page is not mapped, try map it. + Err(PagingError::NotMapped) => { + let map_flags = if self.0.cache.in_memory() { + // For in memory files, we don't need to (and also + // musn't) mark them dirty, so we can use the original + // flags. + flags + } else { + flags - MappingFlags::WRITE + }; + self.0.cache.with_page_or_insert(pn, |page, evicted| { + if let Some((pn, _)) = evicted { + to_be_evicted.push(pn); + } + pt.map(addr, page.paddr(), PageSize::Size4K, map_flags)?; + pages += 1; + Ok(()) + })?; + } + Err(_) => return Err(AxError::BadAddress), + } + } + Ok(( + pages, + if to_be_evicted.is_empty() { + None + } else { + let inner = self.0.clone(); + Some(Box::new(move |aspace: &mut AddrSpace| { + for pn in to_be_evicted { + inner.on_evict(pn, aspace); + } + })) + }, + )) + } + + fn clone_map( + &self, + _range: VirtAddrRange, + _flags: MappingFlags, + _old_pt: &mut PageTableMut, + _new_pt: &mut PageTableMut, + new_aspace: &Arc>, + ) -> AxResult { + let inner = Arc::new(FileBackendInner { + start: self.0.start, + cache: self.0.cache.clone(), + flags: self.0.flags, + offset_page: self.0.offset_page, + handle: AtomicUsize::new(0), + futex_handle: self.0.futex_handle.clone(), + }); + inner.register_listener(new_aspace); + Ok(Backend::File(FileBackend(inner))) + } +} + +impl Backend { + pub fn new_file( + start: VirtAddr, + cache: CachedFile, + flags: FileFlags, + offset: usize, + aspace: &Arc>, + ) -> Self { + let offset_page = (offset / PAGE_SIZE_4K) as u32; + let inner = Arc::new(FileBackendInner { + start, + cache, + flags, + offset_page, + handle: AtomicUsize::new(0), + futex_handle: Arc::new(()), + }); + inner.register_listener(aspace); + Self::File(FileBackend(inner)) + } +} diff --git a/modules/axmm/src/backend/linear.rs b/modules/axmm/src/backend/linear.rs index 3807766590..37bddc5fe9 100644 --- a/modules/axmm/src/backend/linear.rs +++ b/modules/axmm/src/backend/linear.rs @@ -1,46 +1,64 @@ -use axhal::paging::{MappingFlags, PageTable}; -use memory_addr::{PhysAddr, VirtAddr}; +use alloc::sync::Arc; -use super::Backend; +use axerrno::AxResult; +use axhal::paging::{MappingFlags, PageSize, PageTableMut}; +use axsync::Mutex; +use memory_addr::{PhysAddr, PhysAddrRange, VirtAddr, VirtAddrRange}; -impl Backend { - /// Creates a new linear mapping backend. - pub const fn new_linear(pa_va_offset: usize) -> Self { - Self::Linear { pa_va_offset } +use crate::{ + AddrSpace, + backend::{Backend, BackendOps}, +}; + +/// Linear mapping backend. +/// +/// The offset between the virtual address and the physical address is +/// constant, which is specified by `pa_va_offset`. For example, the virtual +/// address `vaddr` is mapped to the physical address `vaddr - pa_va_offset`. +#[derive(Clone)] +pub struct LinearBackend { + offset: isize, +} + +impl LinearBackend { + fn pa(&self, va: VirtAddr) -> PhysAddr { + PhysAddr::from((va.as_usize() as isize - self.offset) as usize) } +} - pub(crate) fn map_linear( - &self, - start: VirtAddr, - size: usize, - flags: MappingFlags, - pt: &mut PageTable, - pa_va_offset: usize, - ) -> bool { - let va_to_pa = |va: VirtAddr| PhysAddr::from(va.as_usize() - pa_va_offset); - debug!( - "map_linear: [{:#x}, {:#x}) -> [{:#x}, {:#x}) {:?}", - start, - start + size, - va_to_pa(start), - va_to_pa(start + size), - flags - ); - pt.map_region(start, va_to_pa, size, flags, false, false) - .map(|tlb| tlb.ignore()) // TLB flush on map is unnecessary, as there are no outdated mappings. - .is_ok() +impl BackendOps for LinearBackend { + fn page_size(&self) -> PageSize { + PageSize::Size4K + } + + fn map(&self, range: VirtAddrRange, flags: MappingFlags, pt: &mut PageTableMut) -> AxResult { + let pa_range = PhysAddrRange::from_start_size(self.pa(range.start), range.size()); + debug!("Linear::map: {range:?} -> {pa_range:?} {flags:?}"); + pt.map_region(range.start, |va| self.pa(va), range.size(), flags, false)?; + Ok(()) + } + + fn unmap(&self, range: VirtAddrRange, pt: &mut PageTableMut) -> AxResult { + let pa_range = PhysAddrRange::from_start_size(self.pa(range.start), range.size()); + debug!("Linear::unmap: {range:?} -> {pa_range:?}"); + pt.unmap_region(range.start, range.size())?; + Ok(()) } - pub(crate) fn unmap_linear( + fn clone_map( &self, - start: VirtAddr, - size: usize, - pt: &mut PageTable, - _pa_va_offset: usize, - ) -> bool { - debug!("unmap_linear: [{:#x}, {:#x})", start, start + size); - pt.unmap_region(start, size, true) - .map(|tlb| tlb.ignore()) // flush each page on unmap, do not flush the entire TLB. - .is_ok() + _range: VirtAddrRange, + _flags: MappingFlags, + _old_pt: &mut PageTableMut, + _new_pt: &mut PageTableMut, + _new_aspace: &Arc>, + ) -> AxResult { + Ok(Backend::Linear(self.clone())) + } +} + +impl Backend { + pub fn new_linear(offset: isize) -> Self { + Self::Linear(LinearBackend { offset }) } } diff --git a/modules/axmm/src/backend/mod.rs b/modules/axmm/src/backend/mod.rs index 5df33a8908..0bc1760564 100644 --- a/modules/axmm/src/backend/mod.rs +++ b/modules/axmm/src/backend/mod.rs @@ -1,58 +1,139 @@ //! Memory mapping backends. +use alloc::{boxed::Box, sync::Arc}; -use axhal::paging::{MappingFlags, PageTable}; -use memory_addr::VirtAddr; +use axalloc::{UsageKind, global_allocator}; +use axerrno::{AxError, AxResult}; +use axhal::{ + mem::{phys_to_virt, virt_to_phys}, + paging::{MappingFlags, PageSize, PageTable, PageTableMut}, +}; +use axsync::Mutex; +use enum_dispatch::enum_dispatch; +use memory_addr::{DynPageIter, PAGE_SIZE_4K, PhysAddr, VirtAddr, VirtAddrRange}; use memory_set::MappingBackend; -mod alloc; -mod linear; +pub mod cow; +pub mod file; +pub mod linear; +pub mod shared; + +pub use shared::SharedPages; + +use crate::AddrSpace; + +fn divide_page(size: usize, page_size: PageSize) -> usize { + assert!(page_size.is_aligned(size), "unaligned"); + size >> (page_size as usize).trailing_zeros() +} + +fn alloc_frame(zeroed: bool, size: PageSize) -> AxResult { + let page_size = size as usize; + let num_pages = page_size / PAGE_SIZE_4K; + let vaddr = + VirtAddr::from(global_allocator().alloc_pages(num_pages, page_size, UsageKind::VirtMem)?); + if zeroed { + unsafe { core::ptr::write_bytes(vaddr.as_mut_ptr(), 0, page_size) }; + } + let paddr = virt_to_phys(vaddr); + + Ok(paddr) +} + +fn dealloc_frame(frame: PhysAddr, align: PageSize) { + let vaddr = phys_to_virt(frame); + let page_size: usize = align.into(); + let num_pages = page_size / PAGE_SIZE_4K; + global_allocator().dealloc_pages(vaddr.as_usize(), num_pages, UsageKind::VirtMem); +} + +fn pages_in(range: VirtAddrRange, align: PageSize) -> AxResult> { + DynPageIter::new(range.start, range.end, align as usize).ok_or(AxError::InvalidInput) +} + +#[enum_dispatch] +pub trait BackendOps { + /// Returns the page size of the backend. + fn page_size(&self) -> PageSize; + + /// Map a memory region. + fn map(&self, range: VirtAddrRange, flags: MappingFlags, pt: &mut PageTableMut) -> AxResult; + + /// Unmap a memory region. + fn unmap(&self, range: VirtAddrRange, pt: &mut PageTableMut) -> AxResult; + + /// Called before a memory region is protected. + fn on_protect( + &self, + _range: VirtAddrRange, + _new_flags: MappingFlags, + _pt: &mut PageTableMut, + ) -> AxResult { + Ok(()) + } + + /// Populate a memory region and return how many pages now satisfy + /// `access_flags`. + /// + /// If another thread has already mapped the page with sufficient permissions, + /// treat it as populated. + fn populate( + &self, + _range: VirtAddrRange, + _flags: MappingFlags, + _access_flags: MappingFlags, + _pt: &mut PageTableMut, + ) -> AxResult<(usize, Option>)> { + Ok((0, None)) + } + + /// Duplicates this mapping for use in a different page table. + /// + /// This differs from `clone`, which is designed for splitting a mapping + /// within the same table. + /// + /// [`BackendOps::map`] will be latter called to the returned backend. + fn clone_map( + &self, + range: VirtAddrRange, + flags: MappingFlags, + old_pt: &mut PageTableMut, + new_pt: &mut PageTableMut, + new_aspace: &Arc>, + ) -> AxResult; +} /// A unified enum type for different memory mapping backends. -/// -/// Currently, two backends are implemented: -/// -/// - **Linear**: used for linear mappings. The target physical frames are -/// contiguous and their addresses should be known when creating the mapping. -/// - **Allocation**: used in general, or for lazy mappings. The target physical -/// frames are obtained from the global allocator. #[derive(Clone)] +#[enum_dispatch(BackendOps)] pub enum Backend { - /// Linear mapping backend. - /// - /// The offset between the virtual address and the physical address is - /// constant, which is specified by `pa_va_offset`. For example, the virtual - /// address `vaddr` is mapped to the physical address `vaddr - pa_va_offset`. - Linear { - /// `vaddr - paddr`. - pa_va_offset: usize, - }, - /// Allocation mapping backend. - /// - /// If `populate` is `true`, all physical frames are allocated when the - /// mapping is created, and no page faults are triggered during the memory - /// access. Otherwise, the physical frames are allocated on demand (by - /// handling page faults). - Alloc { - /// Whether to populate the physical frames when creating the mapping. - populate: bool, - }, + Linear(linear::LinearBackend), + Cow(cow::CowBackend), + Shared(shared::SharedBackend), + File(file::FileBackend), } impl MappingBackend for Backend { type Addr = VirtAddr; type Flags = MappingFlags; type PageTable = PageTable; + fn map(&self, start: VirtAddr, size: usize, flags: MappingFlags, pt: &mut PageTable) -> bool { - match *self { - Self::Linear { pa_va_offset } => self.map_linear(start, size, flags, pt, pa_va_offset), - Self::Alloc { populate } => self.map_alloc(start, size, flags, pt, populate), + let range = VirtAddrRange::from_start_size(start, size); + if let Err(err) = BackendOps::map(self, range, flags, &mut pt.modify()) { + warn!("Failed to map area: {:?}", err); + false + } else { + true } } fn unmap(&self, start: VirtAddr, size: usize, pt: &mut PageTable) -> bool { - match *self { - Self::Linear { pa_va_offset } => self.unmap_linear(start, size, pt, pa_va_offset), - Self::Alloc { populate } => self.unmap_alloc(start, size, pt, populate), + let range = VirtAddrRange::from_start_size(start, size); + if let Err(err) = BackendOps::unmap(self, range, &mut pt.modify()) { + warn!("Failed to unmap area: {:?}", err); + false + } else { + true } } @@ -61,27 +142,8 @@ impl MappingBackend for Backend { start: Self::Addr, size: usize, new_flags: Self::Flags, - page_table: &mut Self::PageTable, + pt: &mut Self::PageTable, ) -> bool { - page_table - .protect_region(start, size, new_flags, true) - .map(|tlb| tlb.ignore()) - .is_ok() - } -} - -impl Backend { - pub(crate) fn handle_page_fault( - &self, - vaddr: VirtAddr, - orig_flags: MappingFlags, - page_table: &mut PageTable, - ) -> bool { - match *self { - Self::Linear { .. } => false, // Linear mappings should not trigger page faults. - Self::Alloc { populate } => { - self.handle_page_fault_alloc(vaddr, orig_flags, page_table, populate) - } - } + pt.modify().protect_region(start, size, new_flags).is_ok() } } diff --git a/modules/axmm/src/backend/shared.rs b/modules/axmm/src/backend/shared.rs new file mode 100644 index 0000000000..6486e9ed17 --- /dev/null +++ b/modules/axmm/src/backend/shared.rs @@ -0,0 +1,111 @@ +use alloc::{sync::Arc, vec::Vec}; +use core::ops::Deref; + +use axerrno::AxResult; +use axhal::paging::{MappingFlags, PageSize, PageTableMut}; +use axsync::Mutex; +use memory_addr::{MemoryAddr, PhysAddr, VirtAddr, VirtAddrRange}; + +use super::{alloc_frame, dealloc_frame}; +use crate::{ + AddrSpace, + backend::{Backend, BackendOps, divide_page, pages_in}, +}; + +pub struct SharedPages { + pub phys_pages: Vec, + pub size: PageSize, +} +impl SharedPages { + pub fn new(size: usize, page_size: PageSize) -> AxResult { + Ok(Self { + phys_pages: (0..divide_page(size, page_size)) + .map(|_| alloc_frame(true, page_size)) + .collect::>()?, + size: page_size, + }) + } + + pub fn len(&self) -> usize { + self.phys_pages.len() + } + + pub fn is_empty(&self) -> bool { + self.phys_pages.is_empty() + } +} + +impl Deref for SharedPages { + type Target = [PhysAddr]; + + fn deref(&self) -> &Self::Target { + &self.phys_pages + } +} + +impl Drop for SharedPages { + fn drop(&mut self) { + for frame in &self.phys_pages { + dealloc_frame(*frame, self.size); + } + } +} + +// FIXME: This implementation does not allow map or unmap partial ranges. +#[derive(Clone)] +pub struct SharedBackend { + start: VirtAddr, + pages: Arc, +} +impl SharedBackend { + pub fn pages(&self) -> &Arc { + &self.pages + } + + fn pages_starting_from(&self, start: VirtAddr) -> &[PhysAddr] { + debug_assert!(start.is_aligned(self.pages.size)); + let start_index = divide_page(start - self.start, self.pages.size); + &self.pages[start_index..] + } +} + +impl BackendOps for SharedBackend { + fn page_size(&self) -> PageSize { + self.pages.size + } + + fn map(&self, range: VirtAddrRange, flags: MappingFlags, pt: &mut PageTableMut) -> AxResult { + debug!("Shared::map: {:?} {:?}", range, flags); + for (vaddr, paddr) in + pages_in(range, self.pages.size)?.zip(self.pages_starting_from(range.start)) + { + pt.map(vaddr, *paddr, self.pages.size, flags)?; + } + Ok(()) + } + + fn unmap(&self, range: VirtAddrRange, pt: &mut PageTableMut) -> AxResult { + debug!("Shared::unmap: {:?}", range); + for vaddr in pages_in(range, self.pages.size)? { + pt.unmap(vaddr)?; + } + Ok(()) + } + + fn clone_map( + &self, + _range: VirtAddrRange, + _flags: MappingFlags, + _old_pt: &mut PageTableMut, + _new_pt: &mut PageTableMut, + _new_aspace: &Arc>, + ) -> AxResult { + Ok(Backend::Shared(self.clone())) + } +} + +impl Backend { + pub fn new_shared(start: VirtAddr, pages: Arc) -> Self { + Self::Shared(SharedBackend { start, pages }) + } +} diff --git a/modules/axmm/src/lib.rs b/modules/axmm/src/lib.rs index 301a1028e9..001d722019 100644 --- a/modules/axmm/src/lib.rs +++ b/modules/axmm/src/lib.rs @@ -4,20 +4,22 @@ #[macro_use] extern crate log; + extern crate alloc; mod aspace; -mod backend; - -pub use self::aspace::AddrSpace; -pub use self::backend::Backend; +pub mod backend; -use axerrno::{AxError, AxResult}; -use axhal::mem::{MemRegionFlags, phys_to_virt}; -use axhal::paging::MappingFlags; +use axerrno::LinuxResult; +use axhal::{ + mem::{MemRegionFlags, phys_to_virt}, + paging::MappingFlags, +}; use kspin::SpinNoIrq; use lazyinit::LazyInit; -use memory_addr::{MemoryAddr, PhysAddr, VirtAddr}; +use memory_addr::{MemoryAddr, PhysAddr, va}; + +pub use self::aspace::AddrSpace; static KERNEL_ASPACE: LazyInit> = LazyInit::new(); @@ -42,9 +44,11 @@ fn reg_flag_to_map_flag(f: MemRegionFlags) -> MappingFlags { } /// Creates a new address space for kernel itself. -pub fn new_kernel_aspace() -> AxResult { - let (base, size) = axhal::mem::kernel_aspace(); - let mut aspace = AddrSpace::new_empty(base, size)?; +pub fn new_kernel_aspace() -> LinuxResult { + let mut aspace = AddrSpace::new_empty( + va!(axconfig::plat::KERNEL_ASPACE_BASE), + axconfig::plat::KERNEL_ASPACE_SIZE, + )?; for r in axhal::mem::memory_regions() { // mapped range should contain the whole region if it is not aligned. let start = r.paddr.align_down_4k(); @@ -77,35 +81,16 @@ pub fn init_memory_management() { info!("Initialize virtual memory management..."); let kernel_aspace = new_kernel_aspace().expect("failed to initialize kernel address space"); - debug!("kernel address space init OK: {kernel_aspace:#x?}"); + debug!("kernel address space init OK: {:#x?}", kernel_aspace); KERNEL_ASPACE.init_once(SpinNoIrq::new(kernel_aspace)); unsafe { axhal::asm::write_kernel_page_table(kernel_page_table_root()) }; + // flush all TLB + axhal::asm::flush_tlb(None); } /// Initializes kernel paging for secondary CPUs. pub fn init_memory_management_secondary() { unsafe { axhal::asm::write_kernel_page_table(kernel_page_table_root()) }; -} - -/// Maps a physical memory region to virtual address space for device access. -pub fn iomap(addr: PhysAddr, size: usize) -> AxResult { - let virt = phys_to_virt(addr); - - let virt_aligned = virt.align_down_4k(); - let addr_aligned = addr.align_down_4k(); - let size_aligned = (addr + size).align_up_4k() - addr_aligned; - - let flags = MappingFlags::DEVICE | MappingFlags::READ | MappingFlags::WRITE; - let mut tb = kernel_aspace().lock(); - match tb.map_linear(virt_aligned, addr_aligned, size_aligned, flags) { - Err(AxError::AlreadyExists) => {} - Err(e) => { - return Err(e); - } - Ok(_) => {} - } - // flush TLB - // FIXME: remove this - tb.protect(virt_aligned, size_aligned, flags)?; - Ok(virt) + // flush all TLB + axhal::asm::flush_tlb(None); } diff --git a/modules/axnet/Cargo.toml b/modules/axnet/Cargo.toml index c8d0e8b828..6a2aacdee3 100644 --- a/modules/axnet/Cargo.toml +++ b/modules/axnet/Cargo.toml @@ -3,8 +3,8 @@ name = "axnet" version.workspace = true edition.workspace = true authors = [ - "Yuekai Jia ", - "ChengXiang Qi ", + "Yuekai Jia ", + "ChengXiang Qi ", ] description = "ArceOS network module" license.workspace = true @@ -13,39 +13,51 @@ repository = "https://github.com/arceos-org/arceos/tree/main/modules/axnet" documentation = "https://arceos-org.github.io/arceos/axnet/index.html" [features] -smoltcp = ["dep:smoltcp"] -default = ["smoltcp"] vsock = ["axdriver/vsock"] [dependencies] +axconfig = { workspace = true } axdriver = { workspace = true, features = ["net"] } -axerrno.workspace = true -axhal.workspace = true -axio.workspace = true -axsync.workspace = true -axtask.workspace = true -cfg-if.workspace = true -lazyinit.workspace = true -log.workspace = true -spin.workspace = true +axhal = { workspace = true } +axsync = { workspace = true } +axtask = { workspace = true } + +axerrno = { workspace = true } +axfs = { workspace = true } +axfs-ng-vfs = { workspace = true } +axio = { workspace = true } +axpoll = { workspace = true } +bitflags = "2.9.1" +cfg-if = { workspace = true } +enum_dispatch = { workspace = true } +hashbrown = "0.16" +lazyinit = { workspace = true } +lazy_static = { workspace = true } +log = { workspace = true } +spin = { workspace = true } +ringbuf = { version = "0.4.8", default-features = false, features = ["alloc"] } +async-channel = { version = "2.5.0", default-features = false } +event-listener = { version = "5.4.0", default-features = false } +async-trait = "0.1.88" [dependencies.smoltcp] -git = "https://github.com/rcore-os/smoltcp.git" -rev = "21a2f82" +git = "https://github.com/Starry-OS/smoltcp.git" +rev = "7401a54" default-features = false -optional = true features = [ - "alloc", - "log", # no std - "medium-ethernet", - "proto-ipv4", - "proto-ipv6", - "socket-raw", - "socket-icmp", - "socket-udp", - "socket-tcp", - "socket-dns", - # "fragmentation-buffer-size-65536", "proto-ipv4-fragmentation", - # "reassembly-buffer-size-65536", "reassembly-buffer-count-32", - # "assembler-max-segment-count-32", + "alloc", + "log", # no std + "async", + "medium-ethernet", + "medium-ip", + "proto-ipv4", + "proto-ipv6", + "socket-raw", + "socket-icmp", + "socket-udp", + "socket-tcp", + "socket-dns", + # "fragmentation-buffer-size-65536", "proto-ipv4-fragmentation", + # "reassembly-buffer-size-65536", "reassembly-buffer-count-32", + # "assembler-max-segment-count-32", ] diff --git a/modules/axnet/src/consts.rs b/modules/axnet/src/consts.rs new file mode 100644 index 0000000000..d229c5eb8b --- /dev/null +++ b/modules/axnet/src/consts.rs @@ -0,0 +1,23 @@ +macro_rules! env_or_default { + ($key:literal) => { + match option_env!($key) { + Some(val) => val, + None => "", + } + }; +} + +pub const IP: &str = env_or_default!("AX_IP"); +pub const GATEWAY: &str = env_or_default!("AX_GW"); +pub const IP_PREFIX: u8 = 24; + +pub const STANDARD_MTU: usize = 1500; + +pub const TCP_RX_BUF_LEN: usize = 64 * 1024; +pub const TCP_TX_BUF_LEN: usize = 64 * 1024; +pub const UDP_RX_BUF_LEN: usize = 64 * 1024; +pub const UDP_TX_BUF_LEN: usize = 64 * 1024; +pub const LISTEN_QUEUE_SIZE: usize = 512; + +pub const SOCKET_BUFFER_SIZE: usize = 64; +pub const ETHERNET_MAX_PENDING_PACKETS: usize = 32; diff --git a/modules/axnet/src/device/ethernet.rs b/modules/axnet/src/device/ethernet.rs new file mode 100644 index 0000000000..359d186bfa --- /dev/null +++ b/modules/axnet/src/device/ethernet.rs @@ -0,0 +1,341 @@ +use alloc::{string::String, vec}; +use core::task::Waker; + +use axdriver::prelude::*; +use axtask::future::register_irq_waker; +use hashbrown::HashMap; +use smoltcp::{ + storage::{PacketBuffer, PacketMetadata}, + time::{Duration, Instant}, + wire::{ + ArpOperation, ArpPacket, ArpRepr, EthernetAddress, EthernetFrame, EthernetProtocol, + EthernetRepr, IpAddress, Ipv4Cidr, + }, +}; + +use crate::{ + consts::{ETHERNET_MAX_PENDING_PACKETS, STANDARD_MTU}, + device::Device, +}; + +const EMPTY_MAC: EthernetAddress = EthernetAddress([0; 6]); + +struct Neighbor { + hardware_address: EthernetAddress, + expires_at: Instant, +} + +pub struct EthernetDevice { + name: String, + inner: AxNetDevice, + neighbors: HashMap>, + ip: Ipv4Cidr, + + pending_packets: PacketBuffer<'static, IpAddress>, +} +impl EthernetDevice { + const NEIGHBOR_TTL: Duration = Duration::from_secs(60); + + pub fn new(name: String, inner: AxNetDevice, ip: Ipv4Cidr) -> Self { + let pending_packets = PacketBuffer::new( + vec![PacketMetadata::EMPTY; ETHERNET_MAX_PENDING_PACKETS], + vec![ + 0u8; + (STANDARD_MTU + EthernetFrame::<&[u8]>::header_len()) + * ETHERNET_MAX_PENDING_PACKETS + ], + ); + Self { + name, + inner, + neighbors: HashMap::new(), + ip, + + pending_packets, + } + } + + #[inline] + fn hardware_address(&self) -> EthernetAddress { + EthernetAddress(self.inner.mac_address().0) + } + + fn send_to( + inner: &mut AxNetDevice, + dst: EthernetAddress, + size: usize, + f: F, + proto: EthernetProtocol, + ) where + F: FnOnce(&mut [u8]), + { + if let Err(err) = inner.recycle_tx_buffers() { + warn!("recycle_tx_buffers failed: {:?}", err); + return; + } + + let repr = EthernetRepr { + src_addr: EthernetAddress(inner.mac_address().0), + dst_addr: dst, + ethertype: proto, + }; + + let mut tx_buf = match inner.alloc_tx_buffer(repr.buffer_len() + size) { + Ok(buf) => buf, + Err(err) => { + warn!("alloc_tx_buffer failed: {:?}", err); + return; + } + }; + let mut frame = EthernetFrame::new_unchecked(tx_buf.packet_mut()); + repr.emit(&mut frame); + f(frame.payload_mut()); + trace!( + "SEND {} bytes: {:02X?}", + tx_buf.packet_len(), + tx_buf.packet() + ); + if let Err(err) = inner.transmit(tx_buf) { + warn!("transmit failed: {:?}", err); + } + } + + fn handle_frame( + &mut self, + frame: &[u8], + buffer: &mut PacketBuffer<()>, + timestamp: Instant, + ) -> bool { + let frame = EthernetFrame::new_unchecked(frame); + let Ok(repr) = EthernetRepr::parse(&frame) else { + warn!("Dropping malformed Ethernet frame"); + return false; + }; + + if !repr.dst_addr.is_broadcast() + && repr.dst_addr != EMPTY_MAC + && repr.dst_addr != self.hardware_address() + { + return false; + } + + match repr.ethertype { + EthernetProtocol::Ipv4 => { + buffer + .enqueue(frame.payload().len(), ()) + .unwrap() + .copy_from_slice(frame.payload()); + return true; + } + EthernetProtocol::Arp => self.process_arp(frame.payload(), timestamp), + _ => {} + } + + false + } + + fn request_arp(&mut self, target_ip: IpAddress) { + let IpAddress::Ipv4(target_ipv4) = target_ip else { + warn!("IPv6 address ARP is not supported: {}", target_ip); + return; + }; + debug!("Requesting ARP for {}", target_ipv4); + + let arp_repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: self.hardware_address(), + source_protocol_addr: self.ip.address(), + target_hardware_addr: EthernetAddress::BROADCAST, + target_protocol_addr: target_ipv4, + }; + + Self::send_to( + &mut self.inner, + EthernetAddress::BROADCAST, + arp_repr.buffer_len(), + |buf| arp_repr.emit(&mut ArpPacket::new_unchecked(buf)), + EthernetProtocol::Arp, + ); + + self.neighbors.insert(target_ip, None); + } + + fn process_arp(&mut self, payload: &[u8], now: Instant) { + let Ok(repr) = ArpPacket::new_checked(payload).and_then(|packet| ArpRepr::parse(&packet)) + else { + warn!("Dropping malformed ARP packet"); + return; + }; + + if let ArpRepr::EthernetIpv4 { + operation, + source_hardware_addr, + source_protocol_addr, + target_hardware_addr, + target_protocol_addr, + } = repr + { + let is_unicast_mac = + target_hardware_addr != EMPTY_MAC && !target_hardware_addr.is_broadcast(); + if is_unicast_mac && self.hardware_address() != target_hardware_addr { + // Only process packet that are for us + return; + } + + if let ArpOperation::Unknown(_) = operation { + return; + } + + if !source_hardware_addr.is_unicast() + || source_protocol_addr.is_broadcast() + || source_protocol_addr.is_multicast() + || source_protocol_addr.is_unspecified() + { + return; + } + if self.ip.address() != target_protocol_addr { + return; + } + + debug!("ARP: {} -> {}", source_protocol_addr, source_hardware_addr); + self.neighbors.insert( + IpAddress::Ipv4(source_protocol_addr), + Some(Neighbor { + hardware_address: source_hardware_addr, + expires_at: now + Self::NEIGHBOR_TTL, + }), + ); + + if let ArpOperation::Request = operation { + let response = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: self.hardware_address(), + source_protocol_addr: self.ip.address(), + target_hardware_addr: source_hardware_addr, + target_protocol_addr: source_protocol_addr, + }; + + Self::send_to( + &mut self.inner, + source_hardware_addr, + response.buffer_len(), + |buf| response.emit(&mut ArpPacket::new_unchecked(buf)), + EthernetProtocol::Arp, + ); + } + + if self + .pending_packets + .peek() + .is_ok_and(|it| it.0 == &IpAddress::Ipv4(source_protocol_addr)) + { + while let Ok((&next_hop, buf)) = self.pending_packets.peek() { + // TODO: optimize logic such that one long-pending ARP + // request does not block all other packets + + let Some(Some(neighbor)) = self.neighbors.get(&next_hop) else { + break; + }; + if neighbor.expires_at <= now { + // Neighbor is expired, we need to request ARP again + self.request_arp(next_hop); + break; + } + + Self::send_to( + &mut self.inner, + neighbor.hardware_address, + buf.len(), + |b| b.copy_from_slice(buf), + EthernetProtocol::Ipv4, + ); + let _ = self.pending_packets.dequeue(); + } + } + } + } +} + +impl Device for EthernetDevice { + fn name(&self) -> &str { + &self.name + } + + fn recv(&mut self, buffer: &mut PacketBuffer<()>, timestamp: Instant) -> bool { + loop { + let rx_buf = match self.inner.receive() { + Ok(buf) => buf, + Err(err) => { + if !matches!(err, DevError::Again) { + warn!("receive failed: {:?}", err); + } + return false; + } + }; + trace!( + "RECV {} bytes: {:02X?}", + rx_buf.packet_len(), + rx_buf.packet() + ); + + let result = self.handle_frame(rx_buf.packet(), buffer, timestamp); + self.inner.recycle_rx_buffer(rx_buf).unwrap(); + if result { + return true; + } + } + } + + fn send(&mut self, next_hop: IpAddress, packet: &[u8], timestamp: Instant) -> bool { + if next_hop.is_broadcast() || self.ip.broadcast().map(IpAddress::Ipv4) == Some(next_hop) { + Self::send_to( + &mut self.inner, + EthernetAddress::BROADCAST, + packet.len(), + |buf| buf.copy_from_slice(packet), + EthernetProtocol::Ipv4, + ); + return false; + } + + let need_request = match self.neighbors.get(&next_hop) { + Some(Some(neighbor)) => { + if neighbor.expires_at > timestamp { + Self::send_to( + &mut self.inner, + neighbor.hardware_address, + packet.len(), + |buf| buf.copy_from_slice(packet), + EthernetProtocol::Ipv4, + ); + return false; + } else { + true + } + } + // Request already sent + Some(None) => false, + None => true, + }; + // Only send ARP request if we haven't already requested it + if need_request { + self.request_arp(next_hop); + } + if self.pending_packets.is_full() { + warn!("Pending packets buffer is full, dropping packet"); + return false; + } + let Ok(dst_buffer) = self.pending_packets.enqueue(packet.len(), next_hop) else { + warn!("Failed to enqueue packet in pending packets buffer"); + return false; + }; + dst_buffer.copy_from_slice(packet); + false + } + + fn register_waker(&self, waker: &Waker) { + if let Some(irq) = self.inner.irq_num() { + register_irq_waker(irq, waker); + } + } +} diff --git a/modules/axnet/src/device/loopback.rs b/modules/axnet/src/device/loopback.rs new file mode 100644 index 0000000000..af0eba22bf --- /dev/null +++ b/modules/axnet/src/device/loopback.rs @@ -0,0 +1,68 @@ +use alloc::vec; +use core::task::Waker; + +use axpoll::PollSet; +use smoltcp::{ + storage::{PacketBuffer, PacketMetadata}, + time::Instant, + wire::IpAddress, +}; + +use crate::{ + consts::{SOCKET_BUFFER_SIZE, STANDARD_MTU}, + device::Device, +}; + +pub struct LoopbackDevice { + buffer: PacketBuffer<'static, ()>, + poll: PollSet, +} +impl LoopbackDevice { + pub fn new() -> Self { + let buffer = PacketBuffer::new( + vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE], + vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE], + ); + Self { + buffer, + poll: PollSet::new(), + } + } +} + +impl Device for LoopbackDevice { + fn name(&self) -> &str { + "lo" + } + + fn recv(&mut self, buffer: &mut PacketBuffer<()>, _timestamp: Instant) -> bool { + self.buffer.dequeue().ok().is_some_and(|(_, rx_buf)| { + buffer + .enqueue(rx_buf.len(), ()) + .unwrap() + .copy_from_slice(rx_buf); + true + }) + } + + fn send(&mut self, next_hop: IpAddress, packet: &[u8], _timestamp: Instant) -> bool { + match self.buffer.enqueue(packet.len(), ()) { + Ok(tx_buf) => { + tx_buf.copy_from_slice(packet); + self.poll.wake(); + true + } + Err(_) => { + warn!( + "Loopback device buffer is full, dropping packet to {}", + next_hop + ); + false + } + } + } + + fn register_waker(&self, waker: &Waker) { + self.poll.register(waker); + } +} diff --git a/modules/axnet/src/device/mod.rs b/modules/axnet/src/device/mod.rs new file mode 100644 index 0000000000..bcdb2fb48a --- /dev/null +++ b/modules/axnet/src/device/mod.rs @@ -0,0 +1,28 @@ +use core::task::Waker; + +use smoltcp::{storage::PacketBuffer, time::Instant, wire::IpAddress}; + +mod ethernet; +mod loopback; +#[cfg(feature = "vsock")] +mod vsock; + +pub use ethernet::*; +pub use loopback::*; +#[cfg(feature = "vsock")] +pub use vsock::*; + +pub trait Device: Send + Sync { + #[allow(unused)] + fn name(&self) -> &str; + + fn recv(&mut self, buffer: &mut PacketBuffer<()>, timestamp: Instant) -> bool; + /// Sends a packet to the next hop. + /// + /// Returns `true` if this operation resulted in the readiness of receive + /// operation. This is true for loopback devices and can be used to speed + /// up packet processing. + fn send(&mut self, next_hop: IpAddress, packet: &[u8], timestamp: Instant) -> bool; + + fn register_waker(&self, waker: &Waker); +} diff --git a/modules/axnet/src/device/vsock.rs b/modules/axnet/src/device/vsock.rs new file mode 100644 index 0000000000..d88590f8d0 --- /dev/null +++ b/modules/axnet/src/device/vsock.rs @@ -0,0 +1,216 @@ +use core::{ + sync::atomic::{AtomicBool, AtomicU64, Ordering}, + time::Duration, +}; + +use axdriver::prelude::*; +use axerrno::{AxError, AxResult, ax_bail}; +use axsync::Mutex; +use axtask::future::{block_on, interruptible}; + +use crate::{alloc::string::ToString, vsock::connection_manager::VSOCK_CONN_MANAGER}; + +// we need a global and static only one vsock device +static VSOCK_DEVICE: Mutex> = Mutex::new(None); + +/// Registers a vsock device. Only one vsock device can be registered. +pub fn register_vsock_device(dev: AxVsockDevice) -> AxResult { + let mut guard = VSOCK_DEVICE.lock(); + if guard.is_some() { + ax_bail!(AlreadyExists, "vsock device already registered"); + } + *guard = Some(dev); + drop(guard); + Ok(()) +} + +static POLL_REF_COUNT: Mutex = Mutex::new(0); +static POLL_TASK_RUNNING: AtomicBool = AtomicBool::new(false); +static POLL_FREQUENCY: PollFrequencyController = PollFrequencyController::new(); + +struct PollFrequencyController { + consecutive_idle: AtomicU64, +} + +impl PollFrequencyController { + const fn new() -> Self { + Self { + consecutive_idle: AtomicU64::new(0), + } + } + + fn current_interval(&self) -> Duration { + let idle = self.consecutive_idle.load(Ordering::Relaxed); + let interval_us = match idle { + 0..=3 => 100, // 3 :100μs + 4..=10 => 500, // 4-10 :500μs + 11..=20 => 2_000, // 11-20 :2ms + _ => 10_000, // 20+ :10ms + }; + Duration::from_micros(interval_us) + } + + fn on_event(&self) { + self.consecutive_idle.store(0, Ordering::Release); + } + + fn on_idle(&self) { + self.consecutive_idle.fetch_add(1, Ordering::Relaxed); + } + + fn stats(&self) -> (u64, u64) { + let idle = self.consecutive_idle.load(Ordering::Relaxed); + let interval = self.current_interval().as_micros() as u64; + (idle, interval) + } +} + +pub fn start_vsock_poll() { + let mut count = POLL_REF_COUNT.lock(); + *count += 1; + let new_count = *count; + debug!("start_vsock_poll: ref_count -> {}", new_count); + if new_count == 1 { + if !POLL_TASK_RUNNING.swap(true, Ordering::SeqCst) { + drop(count); + debug!("Starting vsock poll task"); + axtask::spawn_with_name(vsock_poll_loop, "vsock-poll".to_string()); + } else { + warn!("Poll task already running!"); + } + } +} + +pub fn stop_vsock_poll() { + let mut count = POLL_REF_COUNT.lock(); + if *count == 0 { + // this should not happen, log a warning + warn!("stop_vsock_poll called but ref_count already 0"); + return; + } + *count -= 1; + let new_count = *count; + debug!("stop_vsock_poll: ref_count -> {}", new_count); +} + +fn vsock_poll_loop() { + loop { + let ref_count = *POLL_REF_COUNT.lock(); + if ref_count == 0 { + POLL_TASK_RUNNING.store(false, Ordering::SeqCst); + debug!("Vsock poll task exiting (no active connections)"); + break; + } + let _ = block_on(interruptible(poll_interfaces_adaptive())); + } +} + +async fn poll_interfaces_adaptive() -> AxResult<()> { + let has_events = poll_vsock_interfaces()?; + + if has_events { + POLL_FREQUENCY.on_event(); + } else { + POLL_FREQUENCY.on_idle(); + } + + let interval = POLL_FREQUENCY.current_interval(); + + let (idle_count, interval_us) = POLL_FREQUENCY.stats(); + if idle_count > 0 && idle_count % 10 == 0 { + trace!( + "Poll frequency: idle_count={}, interval={}μs", + idle_count, interval_us + ); + } + axtask::future::sleep(interval).await; + Ok(()) +} + +fn poll_vsock_interfaces() -> AxResult { + let mut guard = VSOCK_DEVICE.lock(); + let dev = guard.as_mut().ok_or(AxError::NotFound)?; + let mut event_count = 0; + let mut buf = alloc::vec![0; 0x1000]; // 4KiB buffer for receiving data + + loop { + match dev.poll_event(&mut buf) { + Ok(None) => break, // no more events + Ok(Some(event)) => { + event_count += 1; + handle_vsock_event(event, &buf); + } + Err(e) => { + info!("Failed to poll vsock event: {:?}", e); + break; + } + } + } + Ok(event_count > 0) +} + +fn handle_vsock_event(event: VsockDriverEvent, buf: &[u8]) { + let mut manager = VSOCK_CONN_MANAGER.lock(); + debug!("Handling vsock event: {:?}", event); + + match event { + VsockDriverEvent::ConnectionRequest(conn_id) => { + let _ = manager.on_connection_request(conn_id); + } + + VsockDriverEvent::Received(conn_id, len) => { + let _ = manager.on_data_received(conn_id, &buf[..len]); + } + + VsockDriverEvent::Disconnected(conn_id) => { + let _ = manager.on_disconnected(conn_id); + } + + VsockDriverEvent::Connected(conn_id) => { + let _ = manager.on_connected(conn_id); + } + + VsockDriverEvent::Unknown => warn!("Received unknown vsock event"), + } +} + +pub fn vsock_listen(addr: VsockAddr) -> AxResult<()> { + let mut guard = VSOCK_DEVICE.lock(); + let dev = guard.as_mut().ok_or(AxError::NotFound)?; + dev.listen(addr.port); + Ok(()) +} + +fn map_dev_err(e: DevError) -> AxError { + match e { + DevError::AlreadyExists => AxError::AlreadyExists, + DevError::Again => AxError::WouldBlock, + DevError::InvalidParam => AxError::InvalidInput, + DevError::Io => AxError::Io, + _ => AxError::BadState, + } +} + +pub fn vsock_connect(conn_id: VsockConnId) -> AxResult<()> { + let mut guard = VSOCK_DEVICE.lock(); + let dev = guard.as_mut().ok_or(AxError::NotFound)?; + dev.connect(conn_id).map_err(map_dev_err) +} + +pub fn vsock_send(conn_id: VsockConnId, buf: &[u8]) -> AxResult { + let mut guard = VSOCK_DEVICE.lock(); + let dev = guard.as_mut().ok_or(AxError::NotFound)?; + dev.send(conn_id, buf).map_err(map_dev_err) +} + +pub fn vsock_disconnect(conn_id: VsockConnId) -> AxResult<()> { + let mut guard = VSOCK_DEVICE.lock(); + let dev = guard.as_mut().ok_or(AxError::NotFound)?; + dev.disconnect(conn_id).map_err(map_dev_err) +} + +pub fn vsock_guest_cid() -> AxResult { + let mut guard = VSOCK_DEVICE.lock(); + let dev = guard.as_mut().ok_or(AxError::NotFound)?; + Ok(dev.guest_cid()) +} diff --git a/modules/axnet/src/general.rs b/modules/axnet/src/general.rs new file mode 100644 index 0000000000..e38b4a21b3 --- /dev/null +++ b/modules/axnet/src/general.rs @@ -0,0 +1,148 @@ +use core::{ + sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}, + task::Waker, + time::Duration, +}; + +use axerrno::AxResult; +use axpoll::{IoEvents, Pollable}; +use axtask::future::{block_on, poll_io, timeout}; + +use crate::{ + SERVICE, + options::{Configurable, GetSocketOption, SetSocketOption}, +}; + +/// General options for all sockets. +pub(crate) struct GeneralOptions { + /// Whether the socket is non-blocking. + nonblock: AtomicBool, + /// Whether the socket should reuse the address. + reuse_address: AtomicBool, + + send_timeout_nanos: AtomicU64, + recv_timeout_nanos: AtomicU64, + + device_mask: AtomicU32, +} +impl Default for GeneralOptions { + fn default() -> Self { + Self::new() + } +} +impl GeneralOptions { + pub fn new() -> Self { + Self { + nonblock: AtomicBool::new(false), + reuse_address: AtomicBool::new(false), + + send_timeout_nanos: AtomicU64::new(0), + recv_timeout_nanos: AtomicU64::new(0), + + device_mask: AtomicU32::new(0), + } + } + + pub fn nonblocking(&self) -> bool { + self.nonblock.load(Ordering::Relaxed) + } + + pub fn reuse_address(&self) -> bool { + self.reuse_address.load(Ordering::Relaxed) + } + + pub fn send_timeout(&self) -> Option { + let nanos = self.send_timeout_nanos.load(Ordering::Relaxed); + (nanos > 0).then(|| Duration::from_nanos(nanos)) + } + + pub fn recv_timeout(&self) -> Option { + let nanos = self.recv_timeout_nanos.load(Ordering::Relaxed); + (nanos > 0).then(|| Duration::from_nanos(nanos)) + } + + pub fn set_device_mask(&self, mask: u32) { + self.device_mask.store(mask, Ordering::Release); + } + + pub fn device_mask(&self) -> u32 { + self.device_mask.load(Ordering::Acquire) + } + + pub fn register_waker(&self, waker: &Waker) { + SERVICE.lock().register_waker(self.device_mask(), waker); + } + + pub fn send_poller<'a, P: Pollable, F: FnMut() -> AxResult, T>( + &self, + pollable: &'a P, + f: F, + ) -> AxResult { + block_on(timeout( + self.send_timeout(), + poll_io(pollable, IoEvents::OUT, self.nonblocking(), f), + ))? + } + + pub fn recv_poller<'a, P: Pollable, F: FnMut() -> AxResult, T>( + &self, + pollable: &'a P, + f: F, + ) -> AxResult { + block_on(timeout( + self.recv_timeout(), + poll_io(pollable, IoEvents::IN, self.nonblocking(), f), + ))? + } +} +impl Configurable for GeneralOptions { + fn get_option_inner(&self, option: &mut GetSocketOption) -> AxResult { + use GetSocketOption as O; + match option { + O::Error(error) => { + // TODO(mivik): actual logic + **error = 0; + } + O::NonBlocking(nonblock) => { + **nonblock = self.nonblocking(); + } + O::ReuseAddress(reuse) => { + **reuse = self.reuse_address(); + } + O::SendTimeout(timeout) => { + **timeout = Duration::from_nanos(self.send_timeout_nanos.load(Ordering::Relaxed)); + } + O::ReceiveTimeout(timeout) => { + **timeout = Duration::from_nanos(self.recv_timeout_nanos.load(Ordering::Relaxed)); + } + _ => return Ok(false), + } + Ok(true) + } + + fn set_option_inner(&self, option: SetSocketOption) -> AxResult { + use SetSocketOption as O; + + match option { + O::NonBlocking(nonblock) => { + self.nonblock.store(*nonblock, Ordering::Relaxed); + } + O::ReuseAddress(reuse) => { + self.reuse_address.store(*reuse, Ordering::Relaxed); + } + O::SendTimeout(timeout) => { + self.send_timeout_nanos + .store(timeout.as_nanos() as u64, Ordering::Relaxed); + } + O::ReceiveTimeout(timeout) => { + self.recv_timeout_nanos + .store(timeout.as_nanos() as u64, Ordering::Relaxed); + } + O::SendBuffer(_) | O::ReceiveBuffer(_) => { + // TODO(mivik): implement buffer size options + } + _ => return Ok(false), + } + Ok(true) + } +} diff --git a/modules/axnet/src/lib.rs b/modules/axnet/src/lib.rs index 3e13a4e50c..3dc4e2bc78 100644 --- a/modules/axnet/src/lib.rs +++ b/modules/axnet/src/lib.rs @@ -1,4 +1,4 @@ -//! [ArceOS](https://github.com/arceos-org/arceos) network module. +//! [ArceOS](https://github.com/rcore-os/arceos) network module. //! //! It provides unified networking primitives for TCP/UDP communication //! using various underlying network stacks. Currently, only [smoltcp] is @@ -10,53 +10,126 @@ //! - [`UdpSocket`]: A UDP socket that provides POSIX-like APIs. //! - [`dns_query`]: Function for DNS query. //! -//! # Cargo Features -//! -//! - `smoltcp`: Use [smoltcp] as the underlying network stack. This is enabled -//! by default. -//! //! [smoltcp]: https://github.com/smoltcp-rs/smoltcp #![no_std] +#![feature(ip_from)] +#![feature(maybe_uninit_slice)] #[macro_use] extern crate log; extern crate alloc; -cfg_if::cfg_if! { - if #[cfg(feature = "smoltcp")] { - mod smoltcp_impl; - use smoltcp_impl as net_impl; - } -} +mod consts; +mod device; +mod general; +mod listen_table; +pub mod options; +mod router; +mod service; +mod socket; +pub(crate) mod state; +pub mod tcp; +pub mod udp; +pub mod unix; +#[cfg(feature = "vsock")] +pub mod vsock; +mod wrapper; -pub use self::net_impl::TcpSocket; -pub use self::net_impl::UdpSocket; -pub use self::net_impl::{bench_receive, bench_transmit}; -pub use self::net_impl::{dns_query, poll_interfaces}; +use alloc::{borrow::ToOwned, boxed::Box}; use axdriver::{AxDeviceContainer, prelude::*}; +use axsync::Mutex; +use lazyinit::LazyInit; +use smoltcp::wire::{EthernetAddress, Ipv4Address, Ipv4Cidr}; +pub use socket::*; + +use crate::{ + consts::{GATEWAY, IP, IP_PREFIX}, + device::{EthernetDevice, LoopbackDevice}, + listen_table::ListenTable, + router::{Router, Rule}, + service::Service, + wrapper::SocketSetWrapper, +}; + +static LISTEN_TABLE: LazyInit = LazyInit::new(); +static SOCKET_SET: LazyInit = LazyInit::new(); + +static SERVICE: LazyInit> = LazyInit::new(); /// Initializes the network subsystem by NIC devices. pub fn init_network(mut net_devs: AxDeviceContainer) { info!("Initialize network subsystem..."); - if let Some(dev) = net_devs.take_one() { + let mut router = Router::new(); + let lo_dev = router.add_device(Box::new(LoopbackDevice::new())); + + let lo_ip = Ipv4Cidr::new(Ipv4Address::new(127, 0, 0, 1), 8); + router.add_rule(Rule::new( + lo_ip.into(), + None, + lo_dev, + lo_ip.address().into(), + )); + + let eth0_ip = if let Some(dev) = net_devs.take_one() { info!(" use NIC 0: {:?}", dev.device_name()); - net_impl::init(dev); + + let eth0_address = EthernetAddress(dev.mac_address().0); + let eth0_ip = Ipv4Cidr::new(IP.parse().expect("Invalid IPv4 address"), IP_PREFIX); + + let eth0_dev = router.add_device(Box::new(EthernetDevice::new( + "eth0".to_owned(), + dev, + eth0_ip, + ))); + + router.add_rule(Rule::new( + Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0).into(), + Some(GATEWAY.parse().expect("Invalid gateway address")), + eth0_dev, + eth0_ip.address().into(), + )); + + info!("eth0:"); + info!(" mac: {}", eth0_address); + info!(" ip: {}", eth0_ip); + + Some(eth0_ip) } else { warn!(" No network device found!"); - } + None + }; + + let mut service = Service::new(router); + service.iface.update_ip_addrs(|ip_addrs| { + ip_addrs.push(lo_ip.into()).unwrap(); + if let Some(eth0_ip) = eth0_ip { + ip_addrs.push(eth0_ip.into()).unwrap(); + } + }); + SERVICE.init_once(Mutex::new(service)); + + SOCKET_SET.init_once(SocketSetWrapper::new()); + LISTEN_TABLE.init_once(ListenTable::new()); } -/// Initializes vsock devices. +/// Init vsock subsystem by vsock devices. #[cfg(feature = "vsock")] pub fn init_vsock(mut vsock_devs: AxDeviceContainer) { + use crate::device::register_vsock_device; info!("Initialize vsock subsystem..."); if let Some(dev) = vsock_devs.take_one() { info!(" use vsock 0: {:?}", dev.device_name()); - warn!(" vsock not implemented yet!"); + if let Err(e) = register_vsock_device(dev) { + warn!("Failed to initialize vsock device: {:?}", e); + } } else { warn!(" No vsock device found!"); } } + +pub fn poll_interfaces() { + while SERVICE.lock().poll(&mut SOCKET_SET.inner.lock()) {} +} diff --git a/modules/axnet/src/listen_table.rs b/modules/axnet/src/listen_table.rs new file mode 100644 index 0000000000..afb2e108f9 --- /dev/null +++ b/modules/axnet/src/listen_table.rs @@ -0,0 +1,168 @@ +use alloc::{boxed::Box, collections::VecDeque, sync::Arc, vec}; +use core::ops::DerefMut; + +use axerrno::{AxError, AxResult}; +use axsync::Mutex; +use smoltcp::{ + iface::{SocketHandle, SocketSet}, + socket::tcp::{self, SocketBuffer, State}, + wire::{IpEndpoint, IpListenEndpoint}, +}; + +use crate::{ + SOCKET_SET, + consts::{LISTEN_QUEUE_SIZE, TCP_RX_BUF_LEN, TCP_TX_BUF_LEN}, +}; + +const PORT_NUM: usize = 65536; + +struct ListenTableEntry { + listen_endpoint: IpListenEndpoint, + syn_queue: VecDeque, +} + +impl ListenTableEntry { + pub fn new(listen_endpoint: IpListenEndpoint) -> Self { + Self { + listen_endpoint, + syn_queue: VecDeque::with_capacity(LISTEN_QUEUE_SIZE), + } + } +} + +impl Drop for ListenTableEntry { + fn drop(&mut self) { + for &handle in &self.syn_queue { + SOCKET_SET.remove(handle); + } + } +} + +pub struct ListenTable { + tcp: Box<[Arc>>>]>, +} + +impl ListenTable { + pub fn new() -> Self { + let tcp = unsafe { + let mut buf = Box::new_uninit_slice(PORT_NUM); + for i in 0..PORT_NUM { + buf[i].write(Arc::default()); + } + buf.assume_init() + }; + Self { tcp } + } + + pub fn can_listen(&self, port: u16) -> bool { + self.tcp[port as usize].lock().is_none() + } + + pub fn listen(&self, listen_endpoint: IpListenEndpoint) -> AxResult { + let port = listen_endpoint.port; + assert_ne!(port, 0); + let mut entry = self.tcp[port as usize].lock(); + if entry.is_none() { + *entry = Some(Box::new(ListenTableEntry::new(listen_endpoint))); + Ok(()) + } else { + warn!("socket already listening on port {port}"); + Err(AxError::AddrInUse) + } + } + + pub fn unlisten(&self, port: u16) { + debug!("TCP socket unlisten on {}", port); + *self.tcp[port as usize].lock() = None; + } + + fn listen_entry(&self, port: u16) -> Arc>>> { + self.tcp[port as usize].clone() + } + + pub fn can_accept(&self, port: u16) -> AxResult { + if let Some(entry) = self.listen_entry(port).lock().as_ref() { + Ok(entry.syn_queue.iter().any(|&handle| is_connected(handle))) + } else { + warn!("accept before listen"); + Err(AxError::InvalidInput) + } + } + + pub fn accept(&self, port: u16) -> AxResult { + let entry = self.listen_entry(port); + let mut table = entry.lock(); + let Some(entry) = table.deref_mut() else { + warn!("accept before listen"); + return Err(AxError::InvalidInput); + }; + + let syn_queue: &mut VecDeque = &mut entry.syn_queue; + let idx = syn_queue + .iter() + .enumerate() + .find_map(|(idx, &handle)| is_connected(handle).then_some(idx)) + .ok_or(AxError::WouldBlock)?; // wait for connection + if idx > 0 { + warn!( + "slow SYN queue enumeration: index = {}, len = {}!", + idx, + syn_queue.len() + ); + } + let handle = syn_queue.swap_remove_front(idx).unwrap(); + // If the connection is reset, return ConnectionReset error + // Otherwise, return the handle and the address tuple + if is_closed(handle) { + warn!("accept failed: connection reset"); + Err(AxError::ConnectionReset) + } else { + Ok(handle) + } + } + + pub fn incoming_tcp_packet( + &self, + src: IpEndpoint, + dst: IpEndpoint, + sockets: &mut SocketSet<'_>, + ) { + if let Some(entry) = self.listen_entry(dst.port).lock().deref_mut() { + // TODO(mivik): accept address check + if entry.syn_queue.len() >= LISTEN_QUEUE_SIZE { + // SYN queue is full, drop the packet + warn!("SYN queue overflow!"); + return; + } + + let mut socket = smoltcp::socket::tcp::Socket::new( + SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]), + SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]), + ); + if let Err(err) = socket.listen(IpListenEndpoint { + addr: None, + port: dst.port, + }) { + warn!("Failed to listen on {}: {:?}", entry.listen_endpoint, err); + return; + } + let handle = sockets.add(socket); + debug!( + "TCP socket {}: prepare for connection {} -> {}", + handle, src, entry.listen_endpoint + ); + entry.syn_queue.push_back(handle); + } + } +} + +fn is_connected(handle: SocketHandle) -> bool { + SOCKET_SET.with_socket::(handle, |socket| { + !matches!(socket.state(), State::Listen | State::SynReceived) + }) +} + +fn is_closed(handle: SocketHandle) -> bool { + SOCKET_SET + .with_socket::(handle, |socket| matches!(socket.state(), State::Closed)) +} diff --git a/modules/axnet/src/options.rs b/modules/axnet/src/options.rs new file mode 100644 index 0000000000..6348f184c2 --- /dev/null +++ b/modules/axnet/src/options.rs @@ -0,0 +1,99 @@ +use core::time::Duration; + +use axerrno::{AxError, AxResult, LinuxError}; +use enum_dispatch::enum_dispatch; + +macro_rules! define_options { + ($($name:ident($value:ty),)*) => { + /// Operation to get a socket option. + /// + /// See [`Configurable::get_option`]. + pub enum GetSocketOption<'a> { + $( + $name(&'a mut $value), + )* + } + + /// Operation to set a socket option. + /// + /// See [`Configurable::set_option`]. + #[derive(Clone, Copy)] + pub enum SetSocketOption<'a> { + $( + $name(&'a $value), + )* + } + }; +} + +/// Corresponds to `struct ucred` in Linux. +#[repr(C)] +#[derive(Default, Debug, Clone)] +pub struct UnixCredentials { + pub pid: u32, + pub uid: u32, + pub gid: u32, +} +impl UnixCredentials { + pub fn new(pid: u32) -> Self { + UnixCredentials { + pid, + uid: 0, + gid: 0, + } + } +} + +define_options! { + // ---- Socket level options (SO_*) ---- + ReuseAddress(bool), + Error(i32), + DontRoute(bool), + SendBuffer(usize), + ReceiveBuffer(usize), + KeepAlive(bool), + SendTimeout(Duration), + ReceiveTimeout(Duration), + SendBufferForce(usize), + PassCredentials(bool), + PeerCredentials(UnixCredentials), + + // --- TCP level options (TCP_*) ---- + NoDelay(bool), + MaxSegment(usize), + TcpInfo(()), + + // ---- IP level options (IP_*) ---- + Ttl(u8), + + // ---- Extra options ---- + NonBlocking(bool), +} + +/// Trait for configurable socket-like objects. +#[enum_dispatch] +pub trait Configurable { + /// Get a socket option, returns `true` if the socket supports the option. + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult; + /// Set a socket option, returns `true` if the socket supports the option. + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult; + + fn get_option(&self, mut opt: GetSocketOption) -> AxResult { + self.get_option_inner(&mut opt).and_then(|supported| { + if !supported { + Err(AxError::from(LinuxError::ENOPROTOOPT)) + } else { + Ok(()) + } + }) + } + fn set_option(&self, opt: SetSocketOption) -> AxResult { + self.set_option_inner(opt).and_then(|supported| { + if !supported { + Err(AxError::from(LinuxError::ENOPROTOOPT)) + } else { + Ok(()) + } + }) + } +} diff --git a/modules/axnet/src/router.rs b/modules/axnet/src/router.rs new file mode 100644 index 0000000000..a6e2f157ce --- /dev/null +++ b/modules/axnet/src/router.rs @@ -0,0 +1,244 @@ +use alloc::{boxed::Box, vec, vec::Vec}; + +use smoltcp::{ + iface::SocketSet, + phy::{DeviceCapabilities, Medium}, + storage::PacketMetadata, + time::Instant, + wire::{IpAddress, IpCidr, IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket}, +}; + +use crate::{ + LISTEN_TABLE, + consts::{SOCKET_BUFFER_SIZE, STANDARD_MTU}, + device::Device, +}; + +#[derive(Debug)] +pub struct Rule { + pub filter: IpCidr, + pub via: Option, + pub dev: usize, + pub src: IpAddress, +} + +impl Rule { + pub fn new(filter: IpCidr, via: Option, dev: usize, src: IpAddress) -> Self { + Self { + filter, + via, + dev, + src, + } + } +} + +type PacketBuffer = smoltcp::storage::PacketBuffer<'static, ()>; + +// TODO(mivik): optimize +pub struct RouteTable { + rules: Vec, +} +impl RouteTable { + pub fn new() -> Self { + Self { rules: Vec::new() } + } + + pub fn add_rule(&mut self, rule: Rule) { + let idx = self + .rules + .binary_search_by(|it| rule.filter.prefix_len().cmp(&it.filter.prefix_len())) + .unwrap_or_else(|idx| idx); + self.rules.insert(idx, rule); + } + + pub fn lookup(&self, dst: &IpAddress) -> Option<&Rule> { + self.rules + .iter() + .find(|rule| rule.filter.contains_addr(dst)) + } +} + +pub struct Router { + rx_buffer: PacketBuffer, + tx_buffer: PacketBuffer, + pub(crate) devices: Vec>, + pub(crate) table: RouteTable, +} +impl Router { + pub fn new() -> Self { + let rx_buffer = PacketBuffer::new( + vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE], + vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE], + ); + let tx_buffer = PacketBuffer::new( + vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE], + vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE], + ); + Self { + rx_buffer, + tx_buffer, + devices: Vec::new(), + table: RouteTable::new(), + } + } + + pub fn add_rule(&mut self, rule: Rule) { + self.table.add_rule(rule); + } + + pub fn add_device(&mut self, device: Box) -> usize { + self.devices.push(device); + self.devices.len() - 1 + } + + pub fn poll(&mut self, timestamp: Instant) { + for dev in &mut self.devices { + while !self.rx_buffer.is_full() && dev.recv(&mut self.rx_buffer, timestamp) {} + } + } + + pub fn dispatch(&mut self, timestamp: Instant) -> bool { + let mut poll_next = false; + while let Ok(((), packet)) = self.tx_buffer.dequeue() { + match IpVersion::of_packet(packet).expect("got invalid IP packet") { + IpVersion::Ipv4 => { + let packet = smoltcp::wire::Ipv4Packet::new_checked(packet) + .expect("got invalid IPv4 packet"); + let dst_addr = IpAddress::Ipv4(packet.dst_addr()); + if packet.dst_addr().is_broadcast() { + let buf = packet.into_inner(); + for dev in &mut self.devices { + poll_next |= dev.send(dst_addr, buf, timestamp); + } + } else { + let Some(rule) = self.table.lookup(&dst_addr) else { + warn!("No route found for destination: {}", dst_addr); + continue; + }; + assert_eq!(rule.src, IpAddress::Ipv4(packet.src_addr())); + + let next_hop = rule.via.unwrap_or(dst_addr); + let dev = &mut self.devices[rule.dev]; + poll_next |= dev.send(next_hop, packet.into_inner(), timestamp); + } + } + IpVersion::Ipv6 => { + let packet = smoltcp::wire::Ipv6Packet::new_checked(packet) + .expect("got invalid IPv6 packet"); + let dst_addr = IpAddress::Ipv6(packet.dst_addr()); + if packet.dst_addr().is_multicast() { + let buf = packet.into_inner(); + for dev in &mut self.devices { + poll_next |= dev.send(dst_addr, buf, timestamp); + } + } else { + let Some(rule) = self.table.lookup(&dst_addr) else { + warn!("No route found for destination: {}", dst_addr); + continue; + }; + assert_eq!(rule.src, IpAddress::Ipv6(packet.src_addr())); + + let next_hop = rule.via.unwrap_or(dst_addr); + let dev = &mut self.devices[rule.dev]; + poll_next |= dev.send(next_hop, packet.into_inner(), timestamp); + } + } + } + } + poll_next + } +} + +pub struct TxToken<'a>(&'a mut PacketBuffer); + +impl smoltcp::phy::TxToken for TxToken<'_> { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(self + .0 + .enqueue(len, ()) + .expect("This was checked before creating the TxToken")) + } +} + +fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) { + let (protocol, src_addr, dst_addr, payload) = match IpVersion::of_packet(buf).unwrap() { + IpVersion::Ipv4 => { + let packet = Ipv4Packet::new_unchecked(buf); + ( + packet.next_header(), + IpAddress::Ipv4(packet.src_addr()), + IpAddress::Ipv4(packet.dst_addr()), + packet.payload(), + ) + } + IpVersion::Ipv6 => { + let packet = Ipv6Packet::new_unchecked(buf); + ( + packet.next_header(), + IpAddress::Ipv6(packet.src_addr()), + IpAddress::Ipv6(packet.dst_addr()), + packet.payload(), + ) + } + }; + if protocol == IpProtocol::Tcp { + let tcp_packet = TcpPacket::new_unchecked(payload); + let src_addr = (src_addr, tcp_packet.src_port()).into(); + let dst_addr = (dst_addr, tcp_packet.dst_port()).into(); + let is_first = tcp_packet.syn() && !tcp_packet.ack(); + if is_first { + LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets); + } + } +} + +pub struct RxToken<'a>(&'a [u8]); + +impl<'a> smoltcp::phy::RxToken for RxToken<'a> { + fn consume(self, f: F) -> R + where + F: FnOnce(&[u8]) -> R, + { + f(self.0) + } + + fn preprocess(&self, sockets: &mut SocketSet) { + snoop_tcp_packet(self.0, sockets); + } +} + +impl smoltcp::phy::Device for Router { + type RxToken<'a> = RxToken<'a>; + type TxToken<'a> = TxToken<'a>; + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + if self.rx_buffer.is_empty() || self.tx_buffer.is_full() { + None + } else { + Some(( + RxToken(self.rx_buffer.dequeue().unwrap().1), + TxToken(&mut self.tx_buffer), + )) + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option> { + if self.tx_buffer.is_full() { + None + } else { + Some(TxToken(&mut self.tx_buffer)) + } + } + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.medium = Medium::Ip; + caps.max_transmission_unit = STANDARD_MTU; + caps.max_burst_size = Some(SOCKET_BUFFER_SIZE); + caps + } +} diff --git a/modules/axnet/src/service.rs b/modules/axnet/src/service.rs new file mode 100644 index 0000000000..6f54027d28 --- /dev/null +++ b/modules/axnet/src/service.rs @@ -0,0 +1,90 @@ +use alloc::boxed::Box; +use core::{ + pin::Pin, + task::{Context, Waker}, +}; + +use axhal::time::{NANOS_PER_MICROS, TimeValue, wall_time_nanos}; +use axtask::future::sleep_until; +use smoltcp::{ + iface::{Interface, SocketSet}, + time::Instant, + wire::{HardwareAddress, IpAddress, IpListenEndpoint}, +}; + +use crate::{SOCKET_SET, router::Router}; + +fn now() -> Instant { + Instant::from_micros_const((wall_time_nanos() / NANOS_PER_MICROS) as i64) +} + +pub struct Service { + pub iface: Interface, + router: Router, + timeout: Option + Send>>>, +} +impl Service { + pub fn new(mut router: Router) -> Self { + let config = smoltcp::iface::Config::new(HardwareAddress::Ip); + let iface = Interface::new(config, &mut router, now()); + + Self { + iface, + router, + timeout: None, + } + } + + pub fn poll(&mut self, sockets: &mut SocketSet) -> bool { + let timestamp = now(); + + self.router.poll(timestamp); + self.iface.poll(timestamp, &mut self.router, sockets); + self.router.dispatch(timestamp) + } + + pub fn get_source_address(&self, dst_addr: &IpAddress) -> IpAddress { + let Some(rule) = self.router.table.lookup(dst_addr) else { + panic!("no route to destination: {dst_addr}"); + }; + rule.src + } + + pub fn device_mask_for(&self, endpoint: &IpListenEndpoint) -> u32 { + match endpoint.addr { + Some(addr) => self + .router + .table + .lookup(&addr) + .map_or(0, |it| 1u32 << it.dev), + None => u32::MAX, + } + } + + pub fn register_waker(&mut self, mask: u32, waker: &Waker) { + let next = self.iface.poll_at(now(), &mut SOCKET_SET.inner.lock()); + + if let Some(t) = next { + let next = TimeValue::from_micros(t.total_micros() as _); + + // drop old timeout future + self.timeout = None; + + let mut fut = Box::pin(sleep_until(next)); + let mut cx = Context::from_waker(waker); + + if fut.as_mut().poll(&mut cx).is_ready() { + waker.wake_by_ref(); + return; + } else { + self.timeout = Some(fut); + } + } + + for (i, device) in self.router.devices.iter().enumerate() { + if mask & (1 << i) != 0 { + device.register_waker(waker); + } + } + } +} diff --git a/modules/axnet/src/smoltcp_impl/addr.rs b/modules/axnet/src/smoltcp_impl/addr.rs deleted file mode 100644 index 7ceaf1a09d..0000000000 --- a/modules/axnet/src/smoltcp_impl/addr.rs +++ /dev/null @@ -1,4 +0,0 @@ -use smoltcp::wire::{IpAddress, IpEndpoint}; - -pub const UNSPECIFIED_IP: IpAddress = IpAddress::v4(0, 0, 0, 0); -pub const UNSPECIFIED_ENDPOINT: IpEndpoint = IpEndpoint::new(UNSPECIFIED_IP, 0); diff --git a/modules/axnet/src/smoltcp_impl/bench.rs b/modules/axnet/src/smoltcp_impl/bench.rs deleted file mode 100644 index 48cd7e3615..0000000000 --- a/modules/axnet/src/smoltcp_impl/bench.rs +++ /dev/null @@ -1,68 +0,0 @@ -use super::{AxNetRxToken, AxNetTxToken, STANDARD_MTU}; -use super::{DeviceWrapper, InterfaceWrapper}; -use smoltcp::phy::{Device, RxToken, TxToken}; - -const GB: usize = 1000 * MB; -const MB: usize = 1000 * KB; -const KB: usize = 1000; - -impl DeviceWrapper { - pub fn bench_transmit_bandwidth(&mut self) { - // 10 Gb - const MAX_SEND_BYTES: usize = 10 * GB; - let mut send_bytes: usize = 0; - let mut past_send_bytes: usize = 0; - let mut past_time = InterfaceWrapper::current_time(); - - // Send bytes - while send_bytes < MAX_SEND_BYTES { - if let Some(tx_token) = self.transmit(InterfaceWrapper::current_time()) { - AxNetTxToken::consume(tx_token, STANDARD_MTU, |tx_buf| { - tx_buf[0..12].fill(1); - // ether type: IPv4 - tx_buf[12..14].copy_from_slice(&[0x08, 0x00]); - tx_buf[14..STANDARD_MTU].fill(1); - }); - send_bytes += STANDARD_MTU; - } - - let current_time = InterfaceWrapper::current_time(); - if (current_time - past_time).secs() == 1 { - let gb = ((send_bytes - past_send_bytes) * 8) / GB; - let mb = (((send_bytes - past_send_bytes) * 8) % GB) / MB; - let gib = (send_bytes - past_send_bytes) / GB; - let mib = ((send_bytes - past_send_bytes) % GB) / MB; - info!("Transmit: {gib}.{mib:03}GBytes, Bandwidth: {gb}.{mb:03}Gbits/sec."); - past_time = current_time; - past_send_bytes = send_bytes; - } - } - } - - pub fn bench_receive_bandwidth(&mut self) { - // 10 Gb - const MAX_RECEIVE_BYTES: usize = 10 * GB; - let mut receive_bytes: usize = 0; - let mut past_receive_bytes: usize = 0; - let mut past_time = InterfaceWrapper::current_time(); - // Receive bytes - while receive_bytes < MAX_RECEIVE_BYTES { - if let Some(rx_token) = self.receive(InterfaceWrapper::current_time()) { - AxNetRxToken::consume(rx_token.0, |rx_buf| { - receive_bytes += rx_buf.len(); - }); - } - - let current_time = InterfaceWrapper::current_time(); - if (current_time - past_time).secs() == 1 { - let gb = ((receive_bytes - past_receive_bytes) * 8) / GB; - let mb = (((receive_bytes - past_receive_bytes) * 8) % GB) / MB; - let gib = (receive_bytes - past_receive_bytes) / GB; - let mib = ((receive_bytes - past_receive_bytes) % GB) / MB; - info!("Receive: {gib}.{mib:03}GBytes, Bandwidth: {gb}.{mb:03}Gbits/sec."); - past_time = current_time; - past_receive_bytes = receive_bytes; - } - } - } -} diff --git a/modules/axnet/src/smoltcp_impl/dns.rs b/modules/axnet/src/smoltcp_impl/dns.rs deleted file mode 100644 index 53dee35066..0000000000 --- a/modules/axnet/src/smoltcp_impl/dns.rs +++ /dev/null @@ -1,89 +0,0 @@ -use alloc::vec::Vec; -use axerrno::{AxError, AxResult, ax_err_type}; -use core::net::IpAddr; - -use smoltcp::iface::SocketHandle; -use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError}; -use smoltcp::wire::DnsQueryType; - -use super::{ETH0, SOCKET_SET, SocketSetWrapper}; - -/// A DNS socket. -struct DnsSocket { - handle: Option, -} - -impl DnsSocket { - #[allow(clippy::new_without_default)] - /// Creates a new DNS socket. - pub fn new() -> Self { - let socket = SocketSetWrapper::new_dns_socket(); - let handle = Some(SOCKET_SET.add(socket)); - Self { handle } - } - - #[allow(dead_code)] - /// Update the list of DNS servers, will replace all existing servers. - pub fn update_servers(self, servers: &[smoltcp::wire::IpAddress]) { - SOCKET_SET.with_socket_mut::(self.handle.unwrap(), |socket| { - socket.update_servers(servers) - }); - } - - /// Query a address with given DNS query type. - pub fn query(&self, name: &str, query_type: DnsQueryType) -> AxResult> { - // let local_addr = self.local_addr.unwrap_or_else(f); - let handle = self.handle.ok_or_else(|| ax_err_type!(InvalidInput))?; - let iface = Ð0.iface; - let query_handle = SOCKET_SET - .with_socket_mut::(handle, |socket| { - socket.start_query(iface.lock().context(), name, query_type) - }) - .map_err(|e| match e { - StartQueryError::NoFreeSlot => { - ax_err_type!(ResourceBusy, "socket query() failed: no free slot") - } - StartQueryError::InvalidName => { - ax_err_type!(InvalidInput, "socket query() failed: invalid name") - } - StartQueryError::NameTooLong => { - ax_err_type!(InvalidInput, "socket query() failed: too long name") - } - })?; - loop { - SOCKET_SET.poll_interfaces(); - match SOCKET_SET.with_socket_mut::(handle, |socket| { - socket.get_query_result(query_handle).map_err(|e| match e { - GetQueryResultError::Pending => AxError::WouldBlock, - GetQueryResultError::Failed => { - ax_err_type!(ConnectionRefused, "socket query() failed") - } - }) - }) { - Ok(n) => { - let mut res = Vec::with_capacity(n.capacity()); - for ip in n { - res.push(IpAddr::from(ip)) - } - return Ok(res); - } - Err(AxError::WouldBlock) => axtask::yield_now(), - Err(e) => return Err(e), - } - } - } -} - -impl Drop for DnsSocket { - fn drop(&mut self) { - if let Some(handle) = self.handle { - SOCKET_SET.remove(handle); - } - } -} - -/// Public function for DNS query. -pub fn dns_query(name: &str) -> AxResult> { - let socket = DnsSocket::new(); - socket.query(name, DnsQueryType::A) -} diff --git a/modules/axnet/src/smoltcp_impl/listen_table.rs b/modules/axnet/src/smoltcp_impl/listen_table.rs deleted file mode 100644 index 10794577cb..0000000000 --- a/modules/axnet/src/smoltcp_impl/listen_table.rs +++ /dev/null @@ -1,155 +0,0 @@ -use alloc::{boxed::Box, collections::VecDeque}; -use core::ops::{Deref, DerefMut}; - -use axerrno::{AxError, AxResult, ax_err}; -use axsync::Mutex; -use smoltcp::iface::{SocketHandle, SocketSet}; -use smoltcp::socket::tcp::{self, State}; -use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint}; - -use super::{LISTEN_QUEUE_SIZE, SOCKET_SET, SocketSetWrapper}; - -const PORT_NUM: usize = 65536; - -struct ListenTableEntry { - listen_endpoint: IpListenEndpoint, - syn_queue: VecDeque, -} - -impl ListenTableEntry { - pub fn new(listen_endpoint: IpListenEndpoint) -> Self { - Self { - listen_endpoint, - syn_queue: VecDeque::with_capacity(LISTEN_QUEUE_SIZE), - } - } - - #[inline] - fn can_accept(&self, dst: IpAddress) -> bool { - match self.listen_endpoint.addr { - Some(addr) => addr == dst, - None => true, - } - } -} - -impl Drop for ListenTableEntry { - fn drop(&mut self) { - for &handle in &self.syn_queue { - SOCKET_SET.remove(handle); - } - } -} - -pub struct ListenTable { - tcp: Box<[Mutex>>]>, -} - -impl ListenTable { - pub fn new() -> Self { - let tcp = unsafe { - let mut buf = Box::new_uninit_slice(PORT_NUM); - for i in 0..PORT_NUM { - buf[i].write(Mutex::new(None)); - } - buf.assume_init() - }; - Self { tcp } - } - - pub fn can_listen(&self, port: u16) -> bool { - self.tcp[port as usize].lock().is_none() - } - - pub fn listen(&self, listen_endpoint: IpListenEndpoint) -> AxResult { - let port = listen_endpoint.port; - assert_ne!(port, 0); - let mut entry = self.tcp[port as usize].lock(); - if entry.is_none() { - *entry = Some(Box::new(ListenTableEntry::new(listen_endpoint))); - Ok(()) - } else { - ax_err!(AddrInUse, "socket listen() failed") - } - } - - pub fn unlisten(&self, port: u16) { - debug!("TCP socket unlisten on {port}"); - *self.tcp[port as usize].lock() = None; - } - - pub fn can_accept(&self, port: u16) -> AxResult { - if let Some(entry) = self.tcp[port as usize].lock().deref() { - Ok(entry.syn_queue.iter().any(|&handle| is_connected(handle))) - } else { - ax_err!(InvalidInput, "socket accept() failed: not listen") - } - } - - pub fn accept(&self, port: u16) -> AxResult<(SocketHandle, (IpEndpoint, IpEndpoint))> { - if let Some(entry) = self.tcp[port as usize].lock().deref_mut() { - let syn_queue = &mut entry.syn_queue; - let (idx, addr_tuple) = syn_queue - .iter() - .enumerate() - .find_map(|(idx, &handle)| { - is_connected(handle).then(|| (idx, get_addr_tuple(handle))) - }) - .ok_or(AxError::WouldBlock)?; // wait for connection - if idx > 0 { - warn!( - "slow SYN queue enumeration: index = {}, len = {}!", - idx, - syn_queue.len() - ); - } - let handle = syn_queue.swap_remove_front(idx).unwrap(); - Ok((handle, addr_tuple)) - } else { - ax_err!(InvalidInput, "socket accept() failed: not listen") - } - } - - pub fn incoming_tcp_packet( - &self, - src: IpEndpoint, - dst: IpEndpoint, - sockets: &mut SocketSet<'_>, - ) { - if let Some(entry) = self.tcp[dst.port as usize].lock().deref_mut() { - if !entry.can_accept(dst.addr) { - // not listening on this address - return; - } - if entry.syn_queue.len() >= LISTEN_QUEUE_SIZE { - // SYN queue is full, drop the packet - warn!("SYN queue overflow!"); - return; - } - let mut socket = SocketSetWrapper::new_tcp_socket(); - if socket.listen(entry.listen_endpoint).is_ok() { - let handle = sockets.add(socket); - debug!( - "TCP socket {}: prepare for connection {} -> {}", - handle, src, entry.listen_endpoint - ); - entry.syn_queue.push_back(handle); - } - } - } -} - -fn is_connected(handle: SocketHandle) -> bool { - SOCKET_SET.with_socket::(handle, |socket| { - !matches!(socket.state(), State::Listen | State::SynReceived) - }) -} - -fn get_addr_tuple(handle: SocketHandle) -> (IpEndpoint, IpEndpoint) { - SOCKET_SET.with_socket::(handle, |socket| { - ( - socket.local_endpoint().unwrap(), - socket.remote_endpoint().unwrap(), - ) - }) -} diff --git a/modules/axnet/src/smoltcp_impl/mod.rs b/modules/axnet/src/smoltcp_impl/mod.rs deleted file mode 100644 index f7f05ae28c..0000000000 --- a/modules/axnet/src/smoltcp_impl/mod.rs +++ /dev/null @@ -1,336 +0,0 @@ -mod addr; -mod bench; -mod dns; -mod listen_table; -mod tcp; -mod udp; - -use alloc::vec; -use core::cell::RefCell; -use core::ops::DerefMut; - -use axdriver::prelude::*; -use axhal::time::{NANOS_PER_MICROS, wall_time_nanos}; -use axsync::Mutex; -use lazyinit::LazyInit; -use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet}; -use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken}; -use smoltcp::socket::{self, AnySocket}; -use smoltcp::time::Instant; -use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr}; - -use self::listen_table::ListenTable; - -pub use self::dns::dns_query; -pub use self::tcp::TcpSocket; -pub use self::udp::UdpSocket; - -macro_rules! env_or_default { - ($key:literal) => { - match option_env!($key) { - Some(val) => val, - None => "", - } - }; -} - -const IP: &str = env_or_default!("AX_IP"); -const GATEWAY: &str = env_or_default!("AX_GW"); -const DNS_SEVER: &str = "8.8.8.8"; -const IP_PREFIX: u8 = 24; - -const STANDARD_MTU: usize = 1500; - -const RANDOM_SEED: u64 = 0xA2CE_05A2_CE05_A2CE; - -const TCP_RX_BUF_LEN: usize = 64 * 1024; -const TCP_TX_BUF_LEN: usize = 64 * 1024; -const UDP_RX_BUF_LEN: usize = 64 * 1024; -const UDP_TX_BUF_LEN: usize = 64 * 1024; -const LISTEN_QUEUE_SIZE: usize = 512; - -static LISTEN_TABLE: LazyInit = LazyInit::new(); -static SOCKET_SET: LazyInit = LazyInit::new(); -static ETH0: LazyInit = LazyInit::new(); - -struct SocketSetWrapper<'a>(Mutex>); - -struct DeviceWrapper { - inner: RefCell, // use `RefCell` is enough since it's wrapped in `Mutex` in `InterfaceWrapper`. -} - -struct InterfaceWrapper { - name: &'static str, - ether_addr: EthernetAddress, - dev: Mutex, - iface: Mutex, -} - -impl<'a> SocketSetWrapper<'a> { - fn new() -> Self { - Self(Mutex::new(SocketSet::new(vec![]))) - } - - pub fn new_tcp_socket() -> socket::tcp::Socket<'a> { - let tcp_rx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]); - let tcp_tx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]); - socket::tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) - } - - pub fn new_udp_socket() -> socket::udp::Socket<'a> { - let udp_rx_buffer = socket::udp::PacketBuffer::new( - vec![socket::udp::PacketMetadata::EMPTY; 8], - vec![0; UDP_RX_BUF_LEN], - ); - let udp_tx_buffer = socket::udp::PacketBuffer::new( - vec![socket::udp::PacketMetadata::EMPTY; 8], - vec![0; UDP_TX_BUF_LEN], - ); - socket::udp::Socket::new(udp_rx_buffer, udp_tx_buffer) - } - - pub fn new_dns_socket() -> socket::dns::Socket<'a> { - let server_addr = DNS_SEVER.parse().expect("invalid DNS server address"); - socket::dns::Socket::new(&[server_addr], vec![]) - } - - pub fn add>(&self, socket: T) -> SocketHandle { - let handle = self.0.lock().add(socket); - debug!("socket {handle}: created"); - handle - } - - pub fn with_socket, R, F>(&self, handle: SocketHandle, f: F) -> R - where - F: FnOnce(&T) -> R, - { - let set = self.0.lock(); - let socket = set.get(handle); - f(socket) - } - - pub fn with_socket_mut, R, F>(&self, handle: SocketHandle, f: F) -> R - where - F: FnOnce(&mut T) -> R, - { - let mut set = self.0.lock(); - let socket = set.get_mut(handle); - f(socket) - } - - pub fn poll_interfaces(&self) { - ETH0.poll(&self.0); - } - - pub fn remove(&self, handle: SocketHandle) { - self.0.lock().remove(handle); - debug!("socket {handle}: destroyed"); - } -} - -impl InterfaceWrapper { - fn new(name: &'static str, dev: AxNetDevice, ether_addr: EthernetAddress) -> Self { - let mut config = Config::new(HardwareAddress::Ethernet(ether_addr)); - config.random_seed = RANDOM_SEED; - - let mut dev = DeviceWrapper::new(dev); - let iface = Mutex::new(Interface::new(config, &mut dev, Self::current_time())); - Self { - name, - ether_addr, - dev: Mutex::new(dev), - iface, - } - } - - fn current_time() -> Instant { - Instant::from_micros_const((wall_time_nanos() / NANOS_PER_MICROS) as i64) - } - - pub fn name(&self) -> &str { - self.name - } - - pub fn ethernet_address(&self) -> EthernetAddress { - self.ether_addr - } - - pub fn setup_ip_addr(&self, ip: IpAddress, prefix_len: u8) { - let mut iface = self.iface.lock(); - iface.update_ip_addrs(|ip_addrs| { - ip_addrs.push(IpCidr::new(ip, prefix_len)).unwrap(); - }); - } - - pub fn setup_gateway(&self, gateway: IpAddress) { - let mut iface = self.iface.lock(); - match gateway { - IpAddress::Ipv4(v4) => iface.routes_mut().add_default_ipv4_route(v4).unwrap(), - IpAddress::Ipv6(v6) => iface.routes_mut().add_default_ipv6_route(v6).unwrap(), - }; - } - - pub fn poll(&self, sockets: &Mutex) { - let mut dev = self.dev.lock(); - let mut iface = self.iface.lock(); - let mut sockets = sockets.lock(); - let timestamp = Self::current_time(); - iface.poll(timestamp, dev.deref_mut(), &mut sockets); - } -} - -impl DeviceWrapper { - fn new(inner: AxNetDevice) -> Self { - Self { - inner: RefCell::new(inner), - } - } -} - -impl Device for DeviceWrapper { - type RxToken<'a> - = AxNetRxToken<'a> - where - Self: 'a; - type TxToken<'a> - = AxNetTxToken<'a> - where - Self: 'a; - - fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let mut dev = self.inner.borrow_mut(); - if let Err(e) = dev.recycle_tx_buffers() { - warn!("recycle_tx_buffers failed: {e:?}"); - return None; - } - - if !dev.can_transmit() { - return None; - } - let rx_buf = match dev.receive() { - Ok(buf) => buf, - Err(err) => { - if !matches!(err, DevError::Again) { - warn!("receive failed: {err:?}"); - } - return None; - } - }; - Some((AxNetRxToken(&self.inner, rx_buf), AxNetTxToken(&self.inner))) - } - - fn transmit(&mut self, _timestamp: Instant) -> Option> { - let mut dev = self.inner.borrow_mut(); - if let Err(e) = dev.recycle_tx_buffers() { - warn!("recycle_tx_buffers failed: {e:?}"); - return None; - } - if dev.can_transmit() { - Some(AxNetTxToken(&self.inner)) - } else { - None - } - } - - fn capabilities(&self) -> DeviceCapabilities { - let mut caps = DeviceCapabilities::default(); - caps.max_transmission_unit = 1514; - caps.max_burst_size = None; - caps.medium = Medium::Ethernet; - caps - } -} - -struct AxNetRxToken<'a>(&'a RefCell, NetBufPtr); -struct AxNetTxToken<'a>(&'a RefCell); - -impl RxToken for AxNetRxToken<'_> { - fn preprocess(&self, sockets: &mut SocketSet<'_>) { - snoop_tcp_packet(self.1.packet(), sockets).ok(); - } - - fn consume(self, f: F) -> R - where - F: FnOnce(&[u8]) -> R, - { - let rx_buf = self.1; - trace!( - "RECV {} bytes: {:02X?}", - rx_buf.packet_len(), - rx_buf.packet() - ); - let result = f(rx_buf.packet()); - self.0.borrow_mut().recycle_rx_buffer(rx_buf).unwrap(); - result - } -} - -impl TxToken for AxNetTxToken<'_> { - fn consume(self, len: usize, f: F) -> R - where - F: FnOnce(&mut [u8]) -> R, - { - let mut dev = self.0.borrow_mut(); - let mut tx_buf = dev.alloc_tx_buffer(len).unwrap(); - let ret = f(tx_buf.packet_mut()); - trace!("SEND {} bytes: {:02X?}", len, tx_buf.packet()); - dev.transmit(tx_buf).unwrap(); - ret - } -} - -fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) -> Result<(), smoltcp::wire::Error> { - use smoltcp::wire::{EthernetFrame, IpProtocol, Ipv4Packet, TcpPacket}; - - let ether_frame = EthernetFrame::new_checked(buf)?; - let ipv4_packet = Ipv4Packet::new_checked(ether_frame.payload())?; - - if ipv4_packet.next_header() == IpProtocol::Tcp { - let tcp_packet = TcpPacket::new_checked(ipv4_packet.payload())?; - let src_addr = (ipv4_packet.src_addr(), tcp_packet.src_port()).into(); - let dst_addr = (ipv4_packet.dst_addr(), tcp_packet.dst_port()).into(); - let is_first = tcp_packet.syn() && !tcp_packet.ack(); - if is_first { - // create a socket for the first incoming TCP packet, as the later accept() returns. - LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets); - } - } - Ok(()) -} - -/// Poll the network stack. -/// -/// It may receive packets from the NIC and process them, and transmit queued -/// packets to the NIC. -pub fn poll_interfaces() { - SOCKET_SET.poll_interfaces(); -} - -/// Benchmark raw socket transmit bandwidth. -pub fn bench_transmit() { - ETH0.dev.lock().bench_transmit_bandwidth(); -} - -/// Benchmark raw socket receive bandwidth. -pub fn bench_receive() { - ETH0.dev.lock().bench_receive_bandwidth(); -} - -pub(crate) fn init(net_dev: AxNetDevice) { - let ether_addr = EthernetAddress(net_dev.mac_address().0); - let eth0 = InterfaceWrapper::new("eth0", net_dev, ether_addr); - - let ip = IP.parse().expect("invalid IP address"); - let gateway = GATEWAY.parse().expect("invalid gateway IP address"); - eth0.setup_ip_addr(ip, IP_PREFIX); - eth0.setup_gateway(gateway); - - ETH0.init_once(eth0); - SOCKET_SET.init_once(SocketSetWrapper::new()); - LISTEN_TABLE.init_once(ListenTable::new()); - - info!("created net interface {:?}:", ETH0.name()); - info!(" ether: {}", ETH0.ethernet_address()); - info!(" ip: {ip}/{IP_PREFIX}"); - info!(" gateway: {gateway}"); -} diff --git a/modules/axnet/src/smoltcp_impl/tcp.rs b/modules/axnet/src/smoltcp_impl/tcp.rs deleted file mode 100644 index 4cd3926835..0000000000 --- a/modules/axnet/src/smoltcp_impl/tcp.rs +++ /dev/null @@ -1,569 +0,0 @@ -use core::cell::UnsafeCell; -use core::net::SocketAddr; -use core::sync::atomic::{AtomicBool, AtomicU8, Ordering}; - -use axerrno::{AxError, AxResult, ax_err, ax_err_type}; -use axio::PollState; -use axsync::Mutex; - -use smoltcp::iface::SocketHandle; -use smoltcp::socket::tcp::{self, ConnectError, State}; -use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; - -use super::addr::UNSPECIFIED_ENDPOINT; -use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper}; - -// State transitions: -// CLOSED -(connect)-> BUSY -> CONNECTING -> CONNECTED -(shutdown)-> BUSY -> CLOSED -// | -// |-(listen)-> BUSY -> LISTENING -(shutdown)-> BUSY -> CLOSED -// | -// -(bind)-> BUSY -> CLOSED -const STATE_CLOSED: u8 = 0; -const STATE_BUSY: u8 = 1; -const STATE_CONNECTING: u8 = 2; -const STATE_CONNECTED: u8 = 3; -const STATE_LISTENING: u8 = 4; - -/// A TCP socket that provides POSIX-like APIs. -/// -/// - [`connect`] is for TCP clients. -/// - [`bind`], [`listen`], and [`accept`] are for TCP servers. -/// - Other methods are for both TCP clients and servers. -/// -/// [`connect`]: TcpSocket::connect -/// [`bind`]: TcpSocket::bind -/// [`listen`]: TcpSocket::listen -/// [`accept`]: TcpSocket::accept -pub struct TcpSocket { - state: AtomicU8, - handle: UnsafeCell>, - local_addr: UnsafeCell, - peer_addr: UnsafeCell, - nonblock: AtomicBool, -} - -unsafe impl Sync for TcpSocket {} - -impl Default for TcpSocket { - fn default() -> Self { - Self::new() - } -} - -impl TcpSocket { - /// Creates a new TCP socket. - pub const fn new() -> Self { - Self { - state: AtomicU8::new(STATE_CLOSED), - handle: UnsafeCell::new(None), - local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT), - peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT), - nonblock: AtomicBool::new(false), - } - } - - /// Creates a new TCP socket that is already connected. - const fn new_connected( - handle: SocketHandle, - local_addr: IpEndpoint, - peer_addr: IpEndpoint, - ) -> Self { - Self { - state: AtomicU8::new(STATE_CONNECTED), - handle: UnsafeCell::new(Some(handle)), - local_addr: UnsafeCell::new(local_addr), - peer_addr: UnsafeCell::new(peer_addr), - nonblock: AtomicBool::new(false), - } - } - - /// Returns the local address and port, or - /// [`Err(NotConnected)`](AxError::NotConnected) if not connected. - pub fn local_addr(&self) -> AxResult { - match self.get_state() { - STATE_CONNECTED | STATE_LISTENING => { - Ok(SocketAddr::from(unsafe { self.local_addr.get().read() })) - } - _ => Err(AxError::NotConnected), - } - } - - /// Returns the remote address and port, or - /// [`Err(NotConnected)`](AxError::NotConnected) if not connected. - pub fn peer_addr(&self) -> AxResult { - match self.get_state() { - STATE_CONNECTED | STATE_LISTENING => { - Ok(SocketAddr::from(unsafe { self.peer_addr.get().read() })) - } - _ => Err(AxError::NotConnected), - } - } - - /// Returns whether this socket is in nonblocking mode. - #[inline] - pub fn is_nonblocking(&self) -> bool { - self.nonblock.load(Ordering::Acquire) - } - - /// Moves this TCP stream into or out of nonblocking mode. - /// - /// This will result in `read`, `write`, `recv` and `send` operations - /// becoming nonblocking, i.e., immediately returning from their calls. - /// If the IO operation is successful, `Ok` is returned and no further - /// action is required. If the IO operation could not be completed and needs - /// to be retried, an error with kind [`Err(WouldBlock)`](AxError::WouldBlock) is - /// returned. - #[inline] - pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblock.store(nonblocking, Ordering::Release); - } - - /// Connects to the given address and port. - /// - /// The local port is generated automatically. - pub fn connect(&self, remote_addr: SocketAddr) -> AxResult { - self.update_state(STATE_CLOSED, STATE_CONNECTING, || { - // SAFETY: no other threads can read or write these fields. - let handle = unsafe { self.handle.get().read() } - .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket())); - - // TODO: check remote addr unreachable - let bound_endpoint = self.bound_endpoint()?; - let iface = Ð0.iface; - let (local_endpoint, remote_endpoint) = SOCKET_SET - .with_socket_mut::(handle, |socket| { - socket - .connect(iface.lock().context(), remote_addr, bound_endpoint) - .or_else(|e| match e { - ConnectError::InvalidState => { - ax_err!(BadState, "socket connect() failed") - } - ConnectError::Unaddressable => { - ax_err!(ConnectionRefused, "socket connect() failed") - } - })?; - AxResult::Ok(( - socket.local_endpoint().unwrap(), - socket.remote_endpoint().unwrap(), - )) - })?; - unsafe { - // SAFETY: no other threads can read or write these fields as we - // have changed the state to `BUSY`. - self.local_addr.get().write(local_endpoint); - self.peer_addr.get().write(remote_endpoint); - self.handle.get().write(Some(handle)); - } - Ok(()) - }) - .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; // EISCONN - - // Here our state must be `CONNECTING`, and only one thread can run here. - if self.is_nonblocking() { - Err(AxError::WouldBlock) - } else { - self.block_on(|| { - let PollState { writable, .. } = self.poll_connect()?; - if !writable { - Err(AxError::WouldBlock) - } else if self.get_state() == STATE_CONNECTED { - Ok(()) - } else { - ax_err!(ConnectionRefused, "socket connect() failed") - } - }) - } - } - - /// Binds an unbound socket to the given address and port. - /// - /// If the given port is 0, it generates one automatically. - /// - /// It's must be called before [`listen`](Self::listen) and - /// [`accept`](Self::accept). - pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult { - self.update_state(STATE_CLOSED, STATE_CLOSED, || { - // TODO: check addr is available - if local_addr.port() == 0 { - local_addr.set_port(get_ephemeral_port()?); - } - // SAFETY: no other threads can read or write `self.local_addr` as we - // have changed the state to `BUSY`. - unsafe { - let old = self.local_addr.get().read(); - if old != UNSPECIFIED_ENDPOINT { - return ax_err!(InvalidInput, "socket bind() failed: already bound"); - } - self.local_addr.get().write(IpEndpoint::from(local_addr)); - } - Ok(()) - }) - .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound")) - } - - /// Starts listening on the bound address and port. - /// - /// It's must be called after [`bind`](Self::bind) and before - /// [`accept`](Self::accept). - pub fn listen(&self) -> AxResult { - self.update_state(STATE_CLOSED, STATE_LISTENING, || { - let bound_endpoint = self.bound_endpoint()?; - unsafe { - (*self.local_addr.get()).port = bound_endpoint.port; - } - LISTEN_TABLE.listen(bound_endpoint)?; - debug!("TCP socket listening on {bound_endpoint}"); - Ok(()) - }) - .unwrap_or(Ok(())) // ignore simultaneous `listen`s. - } - - /// Accepts a new connection. - /// - /// This function will block the calling thread until a new TCP connection - /// is established. When established, a new [`TcpSocket`] is returned. - /// - /// It's must be called after [`bind`](Self::bind) and [`listen`](Self::listen). - pub fn accept(&self) -> AxResult { - if !self.is_listening() { - return ax_err!(InvalidInput, "socket accept() failed: not listen"); - } - - // SAFETY: `self.local_addr` should be initialized after `bind()`. - let local_port = unsafe { self.local_addr.get().read().port }; - self.block_on(|| { - let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?; - debug!("TCP socket accepted a new connection {peer_addr}"); - Ok(TcpSocket::new_connected(handle, local_addr, peer_addr)) - }) - } - - /// Close the connection. - pub fn shutdown(&self) -> AxResult { - // stream - self.update_state(STATE_CONNECTED, STATE_CLOSED, || { - // SAFETY: `self.handle` should be initialized in a connected socket, and - // no other threads can read or write it. - let handle = unsafe { self.handle.get().read().unwrap() }; - SOCKET_SET.with_socket_mut::(handle, |socket| { - debug!("TCP socket {handle}: shutting down"); - socket.close(); - }); - unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address - SOCKET_SET.poll_interfaces(); - Ok(()) - }) - .unwrap_or(Ok(()))?; - - // listener - self.update_state(STATE_LISTENING, STATE_CLOSED, || { - // SAFETY: `self.local_addr` should be initialized in a listening socket, - // and no other threads can read or write it. - let local_port = unsafe { self.local_addr.get().read().port }; - unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address - LISTEN_TABLE.unlisten(local_port); - SOCKET_SET.poll_interfaces(); - Ok(()) - }) - .unwrap_or(Ok(()))?; - - // ignore for other states - Ok(()) - } - - /// Receives data from the socket, stores it in the given buffer. - pub fn recv(&self, buf: &mut [u8]) -> AxResult { - if self.is_connecting() { - return Err(AxError::WouldBlock); - } else if !self.is_connected() { - return ax_err!(NotConnected, "socket recv() failed"); - } - - // SAFETY: `self.handle` should be initialized in a connected socket. - let handle = unsafe { self.handle.get().read().unwrap() }; - self.block_on(|| { - SOCKET_SET.with_socket_mut::(handle, |socket| { - if !socket.is_active() { - // not open - ax_err!(ConnectionRefused, "socket recv() failed") - } else if !socket.may_recv() { - // connection closed - Ok(0) - } else if socket.recv_queue() > 0 { - // data available - // TODO: use socket.recv(|buf| {...}) - let len = socket - .recv_slice(buf) - .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?; - Ok(len) - } else { - // no more data - Err(AxError::WouldBlock) - } - }) - }) - } - - /// Transmits data in the given buffer. - pub fn send(&self, buf: &[u8]) -> AxResult { - if self.is_connecting() { - return Err(AxError::WouldBlock); - } else if !self.is_connected() { - return ax_err!(NotConnected, "socket send() failed"); - } - - // SAFETY: `self.handle` should be initialized in a connected socket. - let handle = unsafe { self.handle.get().read().unwrap() }; - self.block_on(|| { - SOCKET_SET.with_socket_mut::(handle, |socket| { - if !socket.is_active() || !socket.may_send() { - // closed by remote - ax_err!(ConnectionReset, "socket send() failed") - } else if socket.can_send() { - // connected, and the tx buffer is not full - // TODO: use socket.send(|buf| {...}) - let len = socket - .send_slice(buf) - .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?; - Ok(len) - } else { - // tx buffer is full - Err(AxError::WouldBlock) - } - }) - }) - } - - /// Whether the socket is readable or writable. - pub fn poll(&self) -> AxResult { - match self.get_state() { - STATE_CONNECTING => self.poll_connect(), - STATE_CONNECTED => self.poll_stream(), - STATE_LISTENING => self.poll_listener(), - _ => Ok(PollState { - readable: false, - writable: false, - }), - } - } - - /// Checks if Nagle's algorithm is enabled for this TCP socket. - pub fn nodelay(&self) -> AxResult { - if let Some(h) = unsafe { self.handle.get().read() } { - Ok(SOCKET_SET.with_socket::(h, |socket| socket.nagle_enabled())) - } else { - ax_err!(NotConnected, "socket is not connected") - } - } - - /// Enables or disables Nagle's algorithm for this TCP socket. - pub fn set_nodelay(&self, enabled: bool) -> AxResult<()> { - if let Some(h) = unsafe { self.handle.get().read() } { - SOCKET_SET.with_socket_mut::(h, |socket| { - socket.set_nagle_enabled(enabled); - }); - Ok(()) - } else { - ax_err!(NotConnected, "socket is not connected") - } - } - - /// Returns the maximum capacity of the receive buffer in bytes. - pub fn recv_capacity(&self) -> AxResult { - if let Some(h) = unsafe { self.handle.get().read() } { - Ok(SOCKET_SET.with_socket::(h, |socket| socket.recv_capacity())) - } else { - ax_err!(NotConnected, "socket is not connected") - } - } - - /// Returns the maximum capacity of the send buffer in bytes. - pub fn send_capacity(&self) -> AxResult { - if let Some(h) = unsafe { self.handle.get().read() } { - Ok(SOCKET_SET.with_socket::(h, |socket| socket.send_capacity())) - } else { - ax_err!(NotConnected, "socket is not connected") - } - } -} - -/// Private methods -impl TcpSocket { - #[inline] - fn get_state(&self) -> u8 { - self.state.load(Ordering::Acquire) - } - - #[inline] - fn set_state(&self, state: u8) { - self.state.store(state, Ordering::Release); - } - - /// Update the state of the socket atomically. - /// - /// If the current state is `expect`, it first changes the state to `STATE_BUSY`, - /// then calls the given function. If the function returns `Ok`, it changes the - /// state to `new`, otherwise it changes the state back to `expect`. - /// - /// It returns `Ok` if the current state is `expect`, otherwise it returns - /// the current state in `Err`. - fn update_state(&self, expect: u8, new: u8, f: F) -> Result, u8> - where - F: FnOnce() -> AxResult, - { - match self - .state - .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire) - { - Ok(_) => { - let res = f(); - if res.is_ok() { - self.set_state(new); - } else { - self.set_state(expect); - } - Ok(res) - } - Err(old) => Err(old), - } - } - - #[inline] - fn is_connecting(&self) -> bool { - self.get_state() == STATE_CONNECTING - } - - #[inline] - fn is_connected(&self) -> bool { - self.get_state() == STATE_CONNECTED - } - - #[inline] - fn is_listening(&self) -> bool { - self.get_state() == STATE_LISTENING - } - - fn bound_endpoint(&self) -> AxResult { - // SAFETY: no other threads can read or write `self.local_addr`. - let local_addr = unsafe { self.local_addr.get().read() }; - let port = if local_addr.port != 0 { - local_addr.port - } else { - get_ephemeral_port()? - }; - assert_ne!(port, 0); - let addr = if !local_addr.addr.is_unspecified() { - Some(local_addr.addr) - } else { - None - }; - Ok(IpListenEndpoint { addr, port }) - } - - fn poll_connect(&self) -> AxResult { - // SAFETY: `self.handle` should be initialized above. - let handle = unsafe { self.handle.get().read().unwrap() }; - let writable = - SOCKET_SET.with_socket::(handle, |socket| match socket.state() { - State::SynSent => false, // wait for connection - State::Established => { - self.set_state(STATE_CONNECTED); // connected - debug!( - "TCP socket {}: connected to {}", - handle, - socket.remote_endpoint().unwrap(), - ); - true - } - _ => { - unsafe { - self.local_addr.get().write(UNSPECIFIED_ENDPOINT); - self.peer_addr.get().write(UNSPECIFIED_ENDPOINT); - } - self.set_state(STATE_CLOSED); // connection failed - true - } - }); - Ok(PollState { - readable: false, - writable, - }) - } - - fn poll_stream(&self) -> AxResult { - // SAFETY: `self.handle` should be initialized in a connected socket. - let handle = unsafe { self.handle.get().read().unwrap() }; - SOCKET_SET.with_socket::(handle, |socket| { - Ok(PollState { - readable: !socket.may_recv() || socket.can_recv(), - writable: !socket.may_send() || socket.can_send(), - }) - }) - } - - fn poll_listener(&self) -> AxResult { - // SAFETY: `self.local_addr` should be initialized in a listening socket. - let local_addr = unsafe { self.local_addr.get().read() }; - Ok(PollState { - readable: LISTEN_TABLE.can_accept(local_addr.port)?, - writable: false, - }) - } - - /// Block the current thread until the given function completes or fails. - /// - /// If the socket is non-blocking, it calls the function once and returns - /// immediately. Otherwise, it may call the function multiple times if it - /// returns [`Err(WouldBlock)`](AxError::WouldBlock). - fn block_on(&self, mut f: F) -> AxResult - where - F: FnMut() -> AxResult, - { - if self.is_nonblocking() { - f() - } else { - loop { - SOCKET_SET.poll_interfaces(); - match f() { - Ok(t) => return Ok(t), - Err(AxError::WouldBlock) => axtask::yield_now(), - Err(e) => return Err(e), - } - } - } - } -} - -impl Drop for TcpSocket { - fn drop(&mut self) { - self.shutdown().ok(); - // Safe because we have mut reference to `self`. - if let Some(handle) = unsafe { self.handle.get().read() } { - SOCKET_SET.remove(handle); - } - } -} - -fn get_ephemeral_port() -> AxResult { - const PORT_START: u16 = 0xc000; - const PORT_END: u16 = 0xffff; - static CURR: Mutex = Mutex::new(PORT_START); - - let mut curr = CURR.lock(); - let mut tries = 0; - // TODO: more robust - while tries <= PORT_END - PORT_START { - let port = *curr; - if *curr == PORT_END { - *curr = PORT_START; - } else { - *curr += 1; - } - if LISTEN_TABLE.can_listen(port) { - return Ok(port); - } - tries += 1; - } - ax_err!(AddrInUse, "no avaliable ports!") -} diff --git a/modules/axnet/src/smoltcp_impl/udp.rs b/modules/axnet/src/smoltcp_impl/udp.rs deleted file mode 100644 index c9b57d6b51..0000000000 --- a/modules/axnet/src/smoltcp_impl/udp.rs +++ /dev/null @@ -1,295 +0,0 @@ -use core::net::SocketAddr; -use core::sync::atomic::{AtomicBool, Ordering}; - -use axerrno::{AxError, AxResult, ax_err, ax_err_type}; -use axio::PollState; -use axsync::Mutex; -use spin::RwLock; - -use smoltcp::iface::SocketHandle; -use smoltcp::socket::udp::{self, BindError, SendError}; -use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; - -use super::addr::UNSPECIFIED_ENDPOINT; -use super::{SOCKET_SET, SocketSetWrapper}; - -/// A UDP socket that provides POSIX-like APIs. -pub struct UdpSocket { - handle: SocketHandle, - local_addr: RwLock>, - peer_addr: RwLock>, - nonblock: AtomicBool, -} - -impl UdpSocket { - /// Creates a new UDP socket. - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - let socket = SocketSetWrapper::new_udp_socket(); - let handle = SOCKET_SET.add(socket); - Self { - handle, - local_addr: RwLock::new(None), - peer_addr: RwLock::new(None), - nonblock: AtomicBool::new(false), - } - } - - /// Returns the local address and port, or - /// [`Err(NotConnected)`](AxError::NotConnected) if not connected. - pub fn local_addr(&self) -> AxResult { - match self.local_addr.try_read() { - Some(addr) => addr.map(Into::into).ok_or(AxError::NotConnected), - None => Err(AxError::NotConnected), - } - } - - /// Returns the remote address and port, or - /// [`Err(NotConnected)`](AxError::NotConnected) if not connected. - pub fn peer_addr(&self) -> AxResult { - self.remote_endpoint().map(Into::into) - } - - /// Returns whether this socket is in nonblocking mode. - #[inline] - pub fn is_nonblocking(&self) -> bool { - self.nonblock.load(Ordering::Acquire) - } - - /// Moves this UDP socket into or out of nonblocking mode. - /// - /// This will result in `recv`, `recv_from`, `send`, and `send_to` - /// operations becoming nonblocking, i.e., immediately returning from their - /// calls. If the IO operation is successful, `Ok` is returned and no - /// further action is required. If the IO operation could not be completed - /// and needs to be retried, an error with kind - /// [`Err(WouldBlock)`](AxError::WouldBlock) is returned. - #[inline] - pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblock.store(nonblocking, Ordering::Release); - } - - /// Binds an unbound socket to the given address and port. - /// - /// It's must be called before [`send_to`](Self::send_to) and - /// [`recv_from`](Self::recv_from). - pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult { - let mut self_local_addr = self.local_addr.write(); - - if local_addr.port() == 0 { - local_addr.set_port(get_ephemeral_port()?); - } - if self_local_addr.is_some() { - return ax_err!(InvalidInput, "socket bind() failed: already bound"); - } - - let local_endpoint = IpEndpoint::from(local_addr); - let endpoint = IpListenEndpoint { - addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr), - port: local_endpoint.port, - }; - SOCKET_SET.with_socket_mut::(self.handle, |socket| { - socket.bind(endpoint).or_else(|e| match e { - BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"), - BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"), - }) - })?; - - *self_local_addr = Some(local_endpoint); - debug!("UDP socket {}: bound on {}", self.handle, endpoint); - Ok(()) - } - - /// Sends data on the socket to the given address. On success, returns the - /// number of bytes written. - pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult { - if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() { - return ax_err!(InvalidInput, "socket send_to() failed: invalid address"); - } - self.send_impl(buf, IpEndpoint::from(remote_addr)) - } - - /// Receives a single datagram message on the socket. On success, returns - /// the number of bytes read and the origin. - pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> { - self.recv_impl(|socket| match socket.recv_slice(buf) { - Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))), - Err(_) => ax_err!(BadState, "socket recv_from() failed"), - }) - } - - /// Receives a single datagram message on the socket, without removing it from - /// the queue. On success, returns the number of bytes read and the origin. - pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> { - self.recv_impl(|socket| match socket.peek_slice(buf) { - Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))), - Err(_) => ax_err!(BadState, "socket recv_from() failed"), - }) - } - - /// Connects this UDP socket to a remote address, allowing the `send` and - /// `recv` to be used to send data and also applies filters to only receive - /// data from the specified address. - /// - /// The local port will be generated automatically if the socket is not bound. - /// It's must be called before [`send`](Self::send) and - /// [`recv`](Self::recv). - pub fn connect(&self, addr: SocketAddr) -> AxResult { - let mut self_peer_addr = self.peer_addr.write(); - - if self.local_addr.read().is_none() { - self.bind(SocketAddr::from(UNSPECIFIED_ENDPOINT))?; - } - - *self_peer_addr = Some(IpEndpoint::from(addr)); - debug!("UDP socket {}: connected to {}", self.handle, addr); - Ok(()) - } - - /// Sends data on the socket to the remote address to which it is connected. - pub fn send(&self, buf: &[u8]) -> AxResult { - let remote_endpoint = self.remote_endpoint()?; - self.send_impl(buf, remote_endpoint) - } - - /// Receives a single datagram message on the socket from the remote address - /// to which it is connected. On success, returns the number of bytes read. - pub fn recv(&self, buf: &mut [u8]) -> AxResult { - let remote_endpoint = self.remote_endpoint()?; - self.recv_impl(|socket| { - let (len, meta) = socket - .recv_slice(buf) - .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?; - if !remote_endpoint.addr.is_unspecified() && remote_endpoint.addr != meta.endpoint.addr - { - return Err(AxError::WouldBlock); - } - if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port { - return Err(AxError::WouldBlock); - } - Ok(len) - }) - } - - /// Close the socket. - pub fn shutdown(&self) -> AxResult { - SOCKET_SET.with_socket_mut::(self.handle, |socket| { - debug!("UDP socket {}: shutting down", self.handle); - socket.close(); - }); - SOCKET_SET.poll_interfaces(); - Ok(()) - } - - /// Whether the socket is readable or writable. - pub fn poll(&self) -> AxResult { - if self.local_addr.read().is_none() { - return Ok(PollState { - readable: false, - writable: false, - }); - } - SOCKET_SET.with_socket_mut::(self.handle, |socket| { - Ok(PollState { - readable: socket.can_recv(), - writable: socket.can_send(), - }) - }) - } -} - -/// Private methods -impl UdpSocket { - fn remote_endpoint(&self) -> AxResult { - match self.peer_addr.try_read() { - Some(addr) => addr.ok_or(AxError::NotConnected), - None => Err(AxError::NotConnected), - } - } - - fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult { - if self.local_addr.read().is_none() { - return ax_err!(NotConnected, "socket send() failed"); - } - - self.block_on(|| { - SOCKET_SET.with_socket_mut::(self.handle, |socket| { - if socket.can_send() { - socket - .send_slice(buf, remote_endpoint) - .map_err(|e| match e { - SendError::BufferFull => AxError::WouldBlock, - SendError::Unaddressable => { - ax_err_type!(ConnectionRefused, "socket send() failed") - } - })?; - Ok(buf.len()) - } else { - // tx buffer is full - Err(AxError::WouldBlock) - } - }) - }) - } - - fn recv_impl(&self, mut op: F) -> AxResult - where - F: FnMut(&mut udp::Socket) -> AxResult, - { - if self.local_addr.read().is_none() { - return ax_err!(NotConnected, "socket send() failed"); - } - - self.block_on(|| { - SOCKET_SET.with_socket_mut::(self.handle, |socket| { - if socket.can_recv() { - // data available - op(socket) - } else { - // no more data - Err(AxError::WouldBlock) - } - }) - }) - } - - fn block_on(&self, mut f: F) -> AxResult - where - F: FnMut() -> AxResult, - { - if self.is_nonblocking() { - f() - } else { - loop { - SOCKET_SET.poll_interfaces(); - match f() { - Ok(t) => return Ok(t), - Err(AxError::WouldBlock) => axtask::yield_now(), - Err(e) => return Err(e), - } - } - } - } -} - -impl Drop for UdpSocket { - fn drop(&mut self) { - self.shutdown().ok(); - SOCKET_SET.remove(self.handle); - } -} - -fn get_ephemeral_port() -> AxResult { - const PORT_START: u16 = 0xc000; - const PORT_END: u16 = 0xffff; - static CURR: Mutex = Mutex::new(PORT_START); - let mut curr = CURR.lock(); - - let port = *curr; - if *curr == PORT_END { - *curr = PORT_START; - } else { - *curr += 1; - } - Ok(port) -} diff --git a/modules/axnet/src/socket.rs b/modules/axnet/src/socket.rs new file mode 100644 index 0000000000..71041f2513 --- /dev/null +++ b/modules/axnet/src/socket.rs @@ -0,0 +1,195 @@ +use alloc::{boxed::Box, vec::Vec}; +use core::{ + any::Any, + fmt::{self, Debug}, + net::SocketAddr, + task::Context, +}; + +#[cfg(feature = "vsock")] +use axdriver::prelude::VsockAddr; +use axerrno::{AxError, AxResult, LinuxError}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, Pollable}; +use bitflags::bitflags; +use enum_dispatch::enum_dispatch; + +#[cfg(feature = "vsock")] +use crate::vsock::VsockSocket; +use crate::{ + options::{Configurable, GetSocketOption, SetSocketOption}, + tcp::TcpSocket, + udp::UdpSocket, + unix::{UnixSocket, UnixSocketAddr}, +}; + +#[derive(Clone, Debug)] +pub enum SocketAddrEx { + Ip(SocketAddr), + Unix(UnixSocketAddr), + #[cfg(feature = "vsock")] + Vsock(VsockAddr), +} + +impl SocketAddrEx { + pub fn into_ip(self) -> AxResult { + match self { + SocketAddrEx::Ip(addr) => Ok(addr), + SocketAddrEx::Unix(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + #[cfg(feature = "vsock")] + SocketAddrEx::Vsock(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + } + } + + pub fn into_unix(self) -> AxResult { + match self { + SocketAddrEx::Unix(addr) => Ok(addr), + SocketAddrEx::Ip(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + #[cfg(feature = "vsock")] + SocketAddrEx::Vsock(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + } + } + + #[cfg(feature = "vsock")] + pub fn into_vsock(self) -> AxResult { + match self { + SocketAddrEx::Ip(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + SocketAddrEx::Unix(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + SocketAddrEx::Vsock(addr) => Ok(addr), + } + } +} + +bitflags! { + /// Flags for sending data to a socket. + /// + /// See [`SocketOps::send`]. + #[derive(Default, Debug, Clone, Copy)] + pub struct SendFlags: u32 { + } +} + +bitflags! { + /// Flags for receiving data from a socket. + /// + /// See [`SocketOps::recv`]. + #[derive(Default, Debug, Clone, Copy)] + pub struct RecvFlags: u32 { + /// Receive data without removing it from the queue. + const PEEK = 0x01; + /// For datagram-like sockets, requires [`SocketOps::recv`] to return + /// the real size of the datagram, even when it is larger than the + /// buffer. + const TRUNCATE = 0x02; + } +} + +pub type CMsgData = Box; + +/// Options for sending data to a socket. +/// +/// See [`SocketOps::send`]. +#[derive(Default, Debug)] +pub struct SendOptions { + pub to: Option, + pub flags: SendFlags, + pub cmsg: Vec, +} + +/// Options for receiving data from a socket. +/// +/// See [`SocketOps::recv`]. +#[derive(Default)] +pub struct RecvOptions<'a> { + pub from: Option<&'a mut SocketAddrEx>, + pub flags: RecvFlags, + pub cmsg: Option<&'a mut Vec>, +} +impl Debug for RecvOptions<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RecvOptions") + .field("from", &self.from) + .field("flags", &self.flags) + .finish() + } +} + +/// Kind of shutdown operation to perform on a socket. +#[derive(Debug, Clone, Copy)] +pub enum Shutdown { + Read, + Write, + Both, +} +impl Shutdown { + pub fn has_read(&self) -> bool { + matches!(self, Shutdown::Read | Shutdown::Both) + } + + pub fn has_write(&self) -> bool { + matches!(self, Shutdown::Write | Shutdown::Both) + } +} + +/// Operations that can be performed on a socket. +#[enum_dispatch] +pub trait SocketOps: Configurable { + /// Binds an unbound socket to the given address and port. + fn bind(&self, local_addr: SocketAddrEx) -> AxResult; + /// Connects the socket to a remote address. + fn connect(&self, remote_addr: SocketAddrEx) -> AxResult; + + /// Starts listening on the bound address and port. + fn listen(&self) -> AxResult { + Err(AxError::OperationNotSupported) + } + /// Accepts a connection on a listening socket, returning a new socket. + fn accept(&self) -> AxResult { + Err(AxError::OperationNotSupported) + } + + /// Send data to the socket, optionally to a specific address. + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult; + /// Receive data from the socket. + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions<'_>) -> AxResult; + + /// Get the local endpoint of the socket. + fn local_addr(&self) -> AxResult; + /// Get the remote endpoint of the socket. + fn peer_addr(&self) -> AxResult; + + /// Shutdown the socket, closing the connection. + fn shutdown(&self, how: Shutdown) -> AxResult; +} + +/// Network socket abstraction. +#[enum_dispatch(Configurable, SocketOps)] +pub enum Socket { + Udp(UdpSocket), + Tcp(TcpSocket), + Unix(UnixSocket), + #[cfg(feature = "vsock")] + Vsock(VsockSocket), +} + +impl Pollable for Socket { + fn poll(&self) -> IoEvents { + match self { + Socket::Tcp(tcp) => tcp.poll(), + Socket::Udp(udp) => udp.poll(), + Socket::Unix(unix) => unix.poll(), + #[cfg(feature = "vsock")] + Socket::Vsock(vsock) => vsock.poll(), + } + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + match self { + Socket::Tcp(tcp) => tcp.register(context, events), + Socket::Udp(udp) => udp.register(context, events), + Socket::Unix(unix) => unix.register(context, events), + #[cfg(feature = "vsock")] + Socket::Vsock(vsock) => vsock.register(context, events), + } + } +} diff --git a/modules/axnet/src/state.rs b/modules/axnet/src/state.rs new file mode 100644 index 0000000000..a3a28b2ec5 --- /dev/null +++ b/modules/axnet/src/state.rs @@ -0,0 +1,77 @@ +use core::sync::atomic::{AtomicU8, Ordering}; + +use axerrno::AxResult; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum State { + Idle, + Busy, + Connecting, + Connected, + Listening, + Closed, +} + +impl TryFrom for State { + type Error = (); + + fn try_from(value: u8) -> Result { + Ok(match value { + 0 => State::Idle, + 1 => State::Busy, + 2 => State::Connecting, + 3 => State::Connected, + 4 => State::Listening, + 5 => State::Closed, + _ => return Err(()), + }) + } +} + +pub struct StateLock(AtomicU8); +impl StateLock { + pub fn new(state: State) -> Self { + Self(AtomicU8::new(state as u8)) + } + + pub fn get(&self) -> State { + self.0 + .load(Ordering::Acquire) + .try_into() + .expect("invalid state") + } + + pub fn set(&self, state: State) { + self.0.store(state as u8, Ordering::Release); + } + + pub fn lock(&self, expect: State) -> Result { + match self.0.compare_exchange( + expect as u8, + State::Busy as u8, + Ordering::Acquire, + Ordering::Acquire, + ) { + Ok(_) => Ok(StateGuard(self, expect as u8)), + Err(old) => Err(old.try_into().expect("invalid state")), + } + } +} + +#[must_use] +pub struct StateGuard<'a>(&'a StateLock, u8); +impl StateGuard<'_> { + pub fn transit(self, new: State, f: impl FnOnce() -> AxResult) -> AxResult { + match f() { + Ok(result) => { + self.0.0.store(new as u8, Ordering::Release); + Ok(result) + } + Err(err) => { + self.0.0.store(self.1, Ordering::Release); + Err(err) + } + } + } +} diff --git a/modules/axnet/src/tcp.rs b/modules/axnet/src/tcp.rs new file mode 100644 index 0000000000..0182ee50fa --- /dev/null +++ b/modules/axnet/src/tcp.rs @@ -0,0 +1,517 @@ +use alloc::vec; +use core::{ + net::{Ipv4Addr, SocketAddr}, + sync::atomic::{AtomicBool, Ordering}, + task::Context, +}; + +use axerrno::{AxError, AxResult, ax_bail, ax_err_type}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, PollSet, Pollable}; +use axsync::Mutex; +use smoltcp::{ + iface::SocketHandle, + socket::tcp as smol, + time::Duration, + wire::{IpEndpoint, IpListenEndpoint}, +}; + +use super::{LISTEN_TABLE, SOCKET_SET}; +use crate::{ + RecvFlags, RecvOptions, SERVICE, SendOptions, Shutdown, Socket, SocketAddrEx, SocketOps, + consts::{TCP_RX_BUF_LEN, TCP_TX_BUF_LEN}, + general::GeneralOptions, + options::{Configurable, GetSocketOption, SetSocketOption}, + poll_interfaces, + state::*, +}; + +pub(crate) fn new_tcp_socket() -> smol::Socket<'static> { + smol::Socket::new( + smol::SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]), + smol::SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]), + ) +} + +/// A TCP socket that provides POSIX-like APIs. +pub struct TcpSocket { + state: StateLock, + handle: SocketHandle, + + general: GeneralOptions, + rx_closed: AtomicBool, + poll_rx_closed: PollSet, +} + +unsafe impl Sync for TcpSocket {} + +impl TcpSocket { + /// Creates a new TCP socket. + pub fn new() -> Self { + Self { + state: StateLock::new(State::Idle), + handle: SOCKET_SET.add(new_tcp_socket()), + + general: GeneralOptions::new(), + rx_closed: AtomicBool::new(false), + poll_rx_closed: PollSet::new(), + } + } + + /// Creates a new TCP socket that is already connected. + fn new_connected(handle: SocketHandle) -> Self { + let result = Self { + state: StateLock::new(State::Connected), + handle, + + general: GeneralOptions::new(), + rx_closed: AtomicBool::new(false), + poll_rx_closed: PollSet::new(), + }; + result.with_smol_socket(|socket| { + result + .general + .set_device_mask(SERVICE.lock().device_mask_for(&socket.get_bound_endpoint())); + }); + result + } +} + +/// Private methods +impl TcpSocket { + fn state(&self) -> State { + self.state.get() + } + + #[inline] + fn is_listening(&self) -> bool { + self.state() == State::Listening + } + + fn with_smol_socket(&self, f: impl FnOnce(&mut smol::Socket) -> R) -> R { + SOCKET_SET.with_socket_mut::(self.handle, f) + } + + fn bound_endpoint(&self) -> AxResult { + let endpoint = self.with_smol_socket(|socket| socket.get_bound_endpoint()); + if endpoint.port == 0 { + ax_bail!(InvalidInput, "not bound"); + } + Ok(endpoint) + } + + fn poll_connect(&self) -> IoEvents { + let mut events = IoEvents::empty(); + let writable = self.with_smol_socket(|socket| match socket.state() { + smol::State::SynSent => false, // wait for connection + smol::State::Established => { + self.state.set(State::Connected); // connected + debug!( + "TCP socket {}: connected to {}", + self.handle, + socket.remote_endpoint().unwrap(), + ); + true + } + _ => { + self.state.set(State::Closed); // connection failed + true + } + }); + events.set(IoEvents::OUT, writable); + events + } + + fn poll_stream(&self) -> IoEvents { + let mut events = IoEvents::empty(); + self.with_smol_socket(|socket| { + events.set( + IoEvents::IN, + !self.rx_closed.load(Ordering::Acquire) + && (!socket.may_recv() || socket.can_recv()), + ); + events.set(IoEvents::OUT, !socket.may_send() || socket.can_send()); + }); + events + } + + fn poll_listener(&self) -> IoEvents { + let mut events = IoEvents::empty(); + events.set( + IoEvents::IN, + LISTEN_TABLE + .can_accept(self.bound_endpoint().unwrap().port) + .unwrap(), + ); + events + } +} + +impl Configurable for TcpSocket { + fn get_option_inner(&self, option: &mut GetSocketOption) -> AxResult { + use GetSocketOption as O; + + if self.general.get_option_inner(option)? { + return Ok(true); + } + + match option { + O::NoDelay(no_delay) => { + **no_delay = self.with_smol_socket(|socket| !socket.nagle_enabled()); + } + O::KeepAlive(keep_alive) => { + **keep_alive = self.with_smol_socket(|socket| socket.keep_alive().is_some()); + } + O::MaxSegment(max_segment) => { + // TODO(mivik): get actual MSS + **max_segment = 1460; + } + O::SendBuffer(size) => { + **size = TCP_TX_BUF_LEN; + } + O::ReceiveBuffer(size) => { + **size = TCP_RX_BUF_LEN; + } + O::TcpInfo(_) => { + // TODO(mivik): implement TCP_INFO + } + _ => return Ok(false), + } + Ok(true) + } + + fn set_option_inner(&self, option: SetSocketOption) -> AxResult { + use SetSocketOption as O; + + if self.general.set_option_inner(option)? { + return Ok(true); + } + + match option { + O::NoDelay(no_delay) => { + self.with_smol_socket(|socket| { + socket.set_nagle_enabled(!no_delay); + }); + } + O::KeepAlive(keep_alive) => { + self.with_smol_socket(|socket| { + socket.set_keep_alive(keep_alive.then(|| Duration::from_secs(75))); + }); + } + _ => return Ok(false), + } + Ok(true) + } +} +impl SocketOps for TcpSocket { + fn bind(&self, local_addr: SocketAddrEx) -> AxResult { + let mut local_addr = local_addr.into_ip()?; + self.state + .lock(State::Idle) + .map_err(|_| ax_err_type!(InvalidInput, "already bound"))? + .transit(State::Idle, || { + // TODO: check addr is available + if local_addr.port() == 0 { + local_addr.set_port(get_ephemeral_port()?); + } + if !self.general.reuse_address() { + SOCKET_SET.bind_check(local_addr.ip().into(), local_addr.port())?; + } + + self.with_smol_socket(|socket| { + if socket.get_bound_endpoint().port != 0 { + return Err(AxError::InvalidInput); + } + let endpoint = IpListenEndpoint { + addr: if local_addr.ip().is_unspecified() { + None + } else { + Some(local_addr.ip().into()) + }, + port: local_addr.port(), + }; + socket.set_bound_endpoint(endpoint); + self.general + .set_device_mask(SERVICE.lock().device_mask_for(&endpoint)); + Ok(()) + })?; + debug!("TCP socket {}: binding to {}", self.handle, local_addr); + Ok(()) + }) + } + + fn connect(&self, remote_addr: SocketAddrEx) -> AxResult { + let remote_addr = remote_addr.into_ip()?; + self.state + .lock(State::Idle) + .map_err(|state| { + if state == State::Connecting { + AxError::InProgress + } else { + // TODO(mivik): error code + ax_err_type!(AlreadyConnected) + } + })? + .transit(State::Connecting, || { + // TODO: check remote addr unreachable + // let (bound_endpoint, remote_endpoint) = self.get_endpoint_pair(remote_addr)?; + let remote_endpoint = IpEndpoint::from(remote_addr); + let mut bound_endpoint = + self.with_smol_socket(|socket| socket.get_bound_endpoint()); + if bound_endpoint.addr.is_none() { + bound_endpoint.addr = + Some(SERVICE.lock().get_source_address(&remote_endpoint.addr)); + } + if bound_endpoint.port == 0 { + bound_endpoint.port = get_ephemeral_port()?; + } + info!( + "TCP connection from {} to {}", + bound_endpoint, remote_endpoint + ); + + self.with_smol_socket(|socket| { + socket.set_bound_endpoint(bound_endpoint); + self.general + .set_device_mask(SERVICE.lock().device_mask_for(&bound_endpoint)); + socket + .connect( + crate::SERVICE.lock().iface.context(), + remote_endpoint, + bound_endpoint, + ) + .map_err(|e| match e { + smol::ConnectError::InvalidState => { + ax_err_type!(AlreadyConnected) + } + smol::ConnectError::Unaddressable => { + ax_err_type!(ConnectionRefused, "unaddressable") + } + })?; + Ok(()) + }) + })?; + + // Hack: let the server listen + axtask::yield_now(); + + // Here our state must be `CONNECTING`, and only one thread can run here. + self.general.send_poller(self, || { + poll_interfaces(); + let events = self.poll_connect(); + if !events.contains(IoEvents::OUT) { + Err(AxError::WouldBlock) + } else if self.state() == State::Connected { + Ok(()) + } else { + Err(ax_err_type!(ConnectionRefused, "connection refused")) + } + }) + } + + fn listen(&self) -> AxResult { + if let Ok(guard) = self.state.lock(State::Idle) { + guard.transit(State::Listening, || { + let bound_endpoint = self.with_smol_socket(|socket| socket.get_bound_endpoint()); + LISTEN_TABLE.listen(bound_endpoint)?; + debug!("listening on {}", bound_endpoint); + Ok(()) + })?; + } else { + // ignore simultaneous `listen`s. + } + Ok(()) + } + + fn accept(&self) -> AxResult { + if !self.is_listening() { + ax_bail!(InvalidInput, "not listening"); + } + + let bound_port = self.bound_endpoint()?.port; + self.general.recv_poller(self, || { + poll_interfaces(); + LISTEN_TABLE.accept(bound_port).map(|handle| { + let socket = TcpSocket::new_connected(handle); + debug!( + "accepted connection from {}, {}", + handle, + socket.with_smol_socket(|socket| socket.remote_endpoint().unwrap()) + ); + Socket::Tcp(socket) + }) + }) + } + + fn send(&self, src: &mut impl Buf, _options: SendOptions) -> AxResult { + // SAFETY: `self.handle` should be initialized in a connected socket. + self.general.send_poller(self, || { + poll_interfaces(); + self.with_smol_socket(|socket| { + if !socket.is_active() { + Err(AxError::NotConnected) + } else if !socket.can_send() { + Err(AxError::WouldBlock) + } else { + // connected, and the tx buffer is not full + let len = socket + .send(|buffer| { + let result = src.read(buffer); + let len = result.unwrap_or(0); + (len, result) + }) + .map_err(|_| ax_err_type!(NotConnected, "not connected?"))??; + Ok(len) + } + }) + }) + } + + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions<'_>) -> AxResult { + if self.rx_closed.load(Ordering::Acquire) { + return Err(AxError::NotConnected); + } + self.general.recv_poller(self, || { + poll_interfaces(); + self.with_smol_socket(|socket| { + if !socket.is_active() { + Err(AxError::NotConnected) + } else if !socket.may_recv() { + Ok(0) + } else if socket.recv_queue() == 0 { + Err(AxError::WouldBlock) + } else { + if options.flags.contains(RecvFlags::PEEK) { + dst.write( + socket + .peek(dst.remaining_mut()) + .map_err(|_| ax_err_type!(NotConnected, "not connected?"))?, + ) + } else { + socket + .recv(|buf| { + let result = dst.write(buf); + let len = result.unwrap_or(0); + (len, result) + }) + .map_err(|_| ax_err_type!(NotConnected, "not connected?"))? + } + } + }) + }) + } + + fn local_addr(&self) -> AxResult { + self.with_smol_socket(|socket| { + let endpoint = socket.get_bound_endpoint(); + Ok(SocketAddrEx::Ip(SocketAddr::new( + endpoint + .addr + .map_or_else(|| Ipv4Addr::UNSPECIFIED.into(), Into::into), + endpoint.port, + ))) + }) + } + + fn peer_addr(&self) -> AxResult { + self.with_smol_socket(|socket| { + Ok(SocketAddrEx::Ip( + socket + .remote_endpoint() + .ok_or(AxError::NotConnected)? + .into(), + )) + }) + } + + fn shutdown(&self, how: Shutdown) -> AxResult { + // TODO(mivik): shutdown + if how.has_read() { + self.rx_closed.store(true, Ordering::Release); + self.poll_rx_closed.wake(); + } + + // stream + if let Ok(guard) = self.state.lock(State::Connected) { + guard.transit(State::Closed, || { + if how.has_write() { + self.with_smol_socket(|socket| { + debug!("TCP socket {}: shutting down", self.handle); + socket.close(); + }); + } + poll_interfaces(); + Ok(()) + })?; + } + + // listener + if let Ok(guard) = self.state.lock(State::Listening) { + guard.transit(State::Closed, || { + LISTEN_TABLE.unlisten(self.bound_endpoint()?.port); + poll_interfaces(); + Ok(()) + })?; + } + + // ignore for other states + Ok(()) + } +} + +impl Pollable for TcpSocket { + fn poll(&self) -> IoEvents { + poll_interfaces(); + let mut events = match self.state() { + State::Connecting => self.poll_connect(), + State::Connected | State::Idle | State::Closed => self.poll_stream(), + State::Listening => self.poll_listener(), + State::Busy => IoEvents::empty(), + }; + events.set(IoEvents::RDHUP, self.rx_closed.load(Ordering::Acquire)); + events + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + if events.intersects(IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP) { + self.general.register_waker(context.waker()); + } + if events.contains(IoEvents::RDHUP) { + self.poll_rx_closed.register(context.waker()); + } + } +} + +impl Drop for TcpSocket { + fn drop(&mut self) { + if let Err(err) = self.shutdown(Shutdown::Both) { + warn!("TCP socket {}: shutdown failed: {}", self.handle, err); + } + SOCKET_SET.remove(self.handle); + // This is crucial for the close messages to be sent. + poll_interfaces(); + } +} + +fn get_ephemeral_port() -> AxResult { + const PORT_START: u16 = 0xc000; + const PORT_END: u16 = 0xffff; + static CURR: Mutex = Mutex::new(PORT_START); + + let mut curr = CURR.lock(); + let mut tries = 0; + // TODO: more robust + while tries <= PORT_END - PORT_START { + let port = *curr; + if *curr == PORT_END { + *curr = PORT_START; + } else { + *curr += 1; + } + if LISTEN_TABLE.can_listen(port) { + return Ok(port); + } + tries += 1; + } + ax_bail!(AddrInUse, "no available ports"); +} diff --git a/modules/axnet/src/udp.rs b/modules/axnet/src/udp.rs new file mode 100644 index 0000000000..0fcb39413c --- /dev/null +++ b/modules/axnet/src/udp.rs @@ -0,0 +1,355 @@ +use alloc::vec; +use core::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + task::Context, +}; + +use axerrno::{AxError, AxResult, ax_bail, ax_err_type}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, Pollable}; +use axsync::Mutex; +use smoltcp::{ + iface::SocketHandle, + phy::PacketMeta, + socket::udp::{self as smol, UdpMetadata}, + storage::PacketMetadata, + wire::{IpAddress, IpEndpoint, IpListenEndpoint}, +}; +use spin::RwLock; + +use crate::{ + RecvFlags, RecvOptions, SERVICE, SOCKET_SET, SendOptions, Shutdown, SocketAddrEx, SocketOps, + consts::{UDP_RX_BUF_LEN, UDP_TX_BUF_LEN}, + general::GeneralOptions, + options::{Configurable, GetSocketOption, SetSocketOption}, + poll_interfaces, +}; + +pub(crate) fn new_udp_socket() -> smol::Socket<'static> { + // TODO(mivik): buffer size + smol::Socket::new( + smol::PacketBuffer::new(vec![PacketMetadata::EMPTY; 256], vec![0; UDP_RX_BUF_LEN]), + smol::PacketBuffer::new(vec![PacketMetadata::EMPTY; 256], vec![0; UDP_TX_BUF_LEN]), + ) +} + +/// A UDP socket that provides POSIX-like APIs. +pub struct UdpSocket { + handle: SocketHandle, + local_addr: RwLock>, + peer_addr: RwLock>, + + general: GeneralOptions, +} + +impl UdpSocket { + /// Creates a new UDP socket. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let socket = new_udp_socket(); + let handle = SOCKET_SET.add(socket); + + Self { + handle, + local_addr: RwLock::new(None), + peer_addr: RwLock::new(None), + + general: GeneralOptions::new(), + } + } + + fn with_smol_socket(&self, f: impl FnOnce(&mut smol::Socket) -> R) -> R { + SOCKET_SET.with_socket_mut::(self.handle, f) + } + + fn remote_endpoint(&self) -> AxResult<(IpEndpoint, IpAddress)> { + match self.peer_addr.try_read() { + Some(addr) => addr.ok_or(AxError::NotConnected), + None => Err(AxError::NotConnected), + } + } +} + +impl Configurable for UdpSocket { + fn get_option_inner(&self, option: &mut GetSocketOption) -> AxResult { + use GetSocketOption as O; + + if self.general.get_option_inner(option)? { + return Ok(true); + } + match option { + O::Ttl(ttl) => { + self.with_smol_socket(|socket| { + **ttl = socket.hop_limit().unwrap_or(64); + }); + } + O::SendBuffer(size) => { + **size = UDP_TX_BUF_LEN; + } + O::ReceiveBuffer(size) => { + **size = UDP_RX_BUF_LEN; + } + _ => return Ok(false), + } + Ok(true) + } + + fn set_option_inner(&self, option: SetSocketOption) -> AxResult { + use SetSocketOption as O; + + if self.general.set_option_inner(option)? { + return Ok(true); + } + match option { + O::Ttl(ttl) => { + self.with_smol_socket(|socket| { + socket.set_hop_limit(Some(*ttl)); + }); + } + _ => return Ok(false), + } + Ok(true) + } +} +impl SocketOps for UdpSocket { + fn bind(&self, local_addr: SocketAddrEx) -> AxResult { + let mut local_addr = local_addr.into_ip()?; + let mut guard = self.local_addr.write(); + + if local_addr.port() == 0 { + local_addr.set_port(get_ephemeral_port()?); + } + if guard.is_some() { + ax_bail!(InvalidInput, "already bound"); + } + + let local_endpoint = IpEndpoint::from(local_addr); + let endpoint = IpListenEndpoint { + addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr), + port: local_endpoint.port, + }; + + if !self.general.reuse_address() { + // Check if the address is already in use + SOCKET_SET.bind_check(local_endpoint.addr, local_endpoint.port)?; + } + + self.with_smol_socket(|socket| { + socket.bind(endpoint).map_err(|e| match e { + smol::BindError::InvalidState => ax_err_type!(InvalidInput, "already bound"), + smol::BindError::Unaddressable => ax_err_type!(ConnectionRefused, "unaddressable"), + }) + })?; + self.general + .set_device_mask(SERVICE.lock().device_mask_for(&endpoint)); + + *guard = Some(local_endpoint); + info!("UDP socket {}: bound on {}", self.handle, endpoint); + Ok(()) + } + + fn connect(&self, remote_addr: SocketAddrEx) -> AxResult { + let remote_addr = remote_addr.into_ip()?; + let mut guard = self.peer_addr.write(); + if self.local_addr.read().is_none() { + self.bind(SocketAddrEx::Ip(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + 0, + )))?; + } + + let remote_addr = IpEndpoint::from(remote_addr); + let src = SERVICE.lock().get_source_address(&remote_addr.addr); + *guard = Some((remote_addr, src)); + debug!("UDP socket {}: connected to {}", self.handle, remote_addr); + Ok(()) + } + + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult { + let (remote_addr, source_addr) = match options.to { + Some(addr) => { + let addr = IpEndpoint::from(addr.into_ip()?); + let src = SERVICE.lock().get_source_address(&addr.addr); + (addr, src) + } + None => self.remote_endpoint()?, + }; + if remote_addr.port == 0 || remote_addr.addr.is_unspecified() { + ax_bail!(InvalidInput, "invalid address"); + } + + if self.local_addr.read().is_none() { + self.bind(SocketAddrEx::Ip(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + 0, + )))?; + } + self.general.send_poller(self, || { + poll_interfaces(); + self.with_smol_socket(|socket| { + if !socket.is_open() { + // not connected + Err(ax_err_type!(NotConnected)) + } else if !socket.can_send() { + Err(AxError::WouldBlock) + } else { + let buf = socket + .send( + src.remaining(), + UdpMetadata { + endpoint: remote_addr, + local_address: Some(source_addr), + meta: PacketMeta::default(), + }, + ) + .map_err(|e| match e { + smol::SendError::BufferFull => AxError::WouldBlock, + smol::SendError::Unaddressable => { + ax_err_type!(ConnectionRefused, "unaddressable") + } + })?; + let read = src.read(buf)?; + assert_eq!(read, buf.len()); + Ok(read) + } + }) + }) + } + + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions) -> AxResult { + if self.local_addr.read().is_none() { + ax_bail!(NotConnected); + } + + enum ExpectedRemote<'a> { + Any(&'a mut SocketAddrEx), + Expecting(IpEndpoint), + } + let mut expected_remote = match options.from { + Some(addr) => ExpectedRemote::Any(addr), + None => ExpectedRemote::Expecting(self.remote_endpoint()?.0), + }; + + self.general.recv_poller(self, || { + poll_interfaces(); + self.with_smol_socket(|socket| { + if !socket.is_open() { + // not bound + Err(ax_err_type!(NotConnected)) + } else if !socket.can_recv() { + Err(AxError::WouldBlock) + } else { + let result = if options.flags.contains(RecvFlags::PEEK) { + socket.peek().map(|(data, meta)| (data, *meta)) + } else { + socket.recv() + }; + match result { + Ok((src, meta)) => { + match &mut expected_remote { + ExpectedRemote::Any(remote_addr) => { + **remote_addr = SocketAddrEx::Ip(meta.endpoint.into()); + } + ExpectedRemote::Expecting(expected) => { + if (!expected.addr.is_unspecified() + && expected.addr != meta.endpoint.addr) + || (expected.port != 0 + && expected.port != meta.endpoint.port) + { + return Err(AxError::WouldBlock); + } + } + } + + let read = dst.write(src)?; + if read < src.len() { + warn!("UDP message truncated: {} -> {} bytes", src.len(), read); + } + + Ok(if options.flags.contains(RecvFlags::TRUNCATE) { + src.len() + } else { + read + }) + } + Err(smol::RecvError::Exhausted) => Err(AxError::WouldBlock), + Err(smol::RecvError::Truncated) => { + unreachable!("UDP socket recv never returns Err(Truncated)") + } + } + } + }) + }) + } + + fn local_addr(&self) -> AxResult { + match self.local_addr.try_read() { + Some(addr) => addr + .map(Into::into) + .map(SocketAddrEx::Ip) + .ok_or(AxError::NotConnected), + None => Err(AxError::NotConnected), + } + } + + fn peer_addr(&self) -> AxResult { + self.remote_endpoint() + .map(|it| it.0.into()) + .map(SocketAddrEx::Ip) + } + + fn shutdown(&self, _how: Shutdown) -> AxResult { + // TODO(mivik): shutdown + poll_interfaces(); + + self.with_smol_socket(|socket| { + debug!("UDP socket {}: shutting down", self.handle); + socket.close(); + }); + Ok(()) + } +} + +impl Pollable for UdpSocket { + fn poll(&self) -> IoEvents { + poll_interfaces(); + if self.local_addr.read().is_none() { + return IoEvents::empty(); + } + + let mut events = IoEvents::empty(); + self.with_smol_socket(|socket| { + events.set(IoEvents::IN, socket.can_recv()); + events.set(IoEvents::OUT, socket.can_send()); + }); + events + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + if events.intersects(IoEvents::IN | IoEvents::OUT) { + self.general.register_waker(context.waker()); + } + } +} + +impl Drop for UdpSocket { + fn drop(&mut self) { + self.shutdown(Shutdown::Both).ok(); + SOCKET_SET.remove(self.handle); + } +} + +fn get_ephemeral_port() -> AxResult { + const PORT_START: u16 = 0xc000; + const PORT_END: u16 = 0xffff; + static CURR: Mutex = Mutex::new(PORT_START); + let mut curr = CURR.lock(); + + let port = *curr; + if *curr == PORT_END { + *curr = PORT_START; + } else { + *curr += 1; + } + Ok(port) +} diff --git a/modules/axnet/src/unix.rs b/modules/axnet/src/unix.rs new file mode 100644 index 0000000000..48146294e6 --- /dev/null +++ b/modules/axnet/src/unix.rs @@ -0,0 +1,229 @@ +pub(crate) mod dgram; +pub(crate) mod stream; + +use alloc::{boxed::Box, sync::Arc}; +use core::task::Context; + +use async_trait::async_trait; +use axerrno::{AxError, AxResult}; +use axfs::{FS_CONTEXT, OpenOptions}; +use axfs_ng_vfs::NodeType; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, Pollable}; +use axsync::Mutex; +use axtask::future::{block_on, interruptible}; +use enum_dispatch::enum_dispatch; +use hashbrown::HashMap; +use lazy_static::lazy_static; + +pub use self::{dgram::DgramTransport, stream::StreamTransport}; +use crate::{ + RecvOptions, SendOptions, Shutdown, Socket, SocketAddrEx, SocketOps, + options::{Configurable, GetSocketOption, SetSocketOption}, +}; + +#[derive(Default, Clone, Debug)] +pub enum UnixSocketAddr { + #[default] + Unnamed, + Abstract(Arc<[u8]>), + Path(Arc), +} + +/// Abstract transport trait for Unix sockets. +#[async_trait] +#[enum_dispatch] +pub trait TransportOps: Configurable + Pollable + Send + Sync { + fn bind(&self, slot: &BindSlot, local_addr: &UnixSocketAddr) -> AxResult; + fn connect(&self, slot: &BindSlot, local_addr: &UnixSocketAddr) -> AxResult; + + async fn accept(&self) -> AxResult<(Transport, UnixSocketAddr)>; + + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult; + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions<'_>) -> AxResult; + + fn shutdown(&self, _how: Shutdown) -> AxResult { + Ok(()) + } +} + +#[enum_dispatch(Configurable, TransportOps)] +pub enum Transport { + Stream(StreamTransport), + Dgram(DgramTransport), +} +impl Pollable for Transport { + fn poll(&self) -> IoEvents { + match self { + Transport::Stream(stream) => stream.poll(), + Transport::Dgram(dgram) => dgram.poll(), + } + } + + fn register(&self, context: &mut core::task::Context<'_>, events: IoEvents) { + match self { + Transport::Stream(stream) => stream.register(context, events), + Transport::Dgram(dgram) => dgram.register(context, events), + } + } +} + +#[derive(Default)] +pub struct BindSlot { + stream: Mutex>, + dgram: Mutex>, +} + +lazy_static! { + static ref ABSTRACT_BINDS: Mutex, BindSlot>> = Mutex::new(HashMap::new()); +} + +pub(crate) fn with_slot( + addr: &UnixSocketAddr, + f: impl FnOnce(&BindSlot) -> AxResult, +) -> AxResult { + match addr { + UnixSocketAddr::Unnamed => Err(AxError::InvalidInput), + UnixSocketAddr::Abstract(name) => { + let binds = ABSTRACT_BINDS.lock(); + if let Some(slot) = binds.get(name) { + f(slot) + } else { + Err(AxError::NotFound) + } + } + UnixSocketAddr::Path(path) => { + let loc = FS_CONTEXT.lock().resolve(path.as_ref())?; + if loc.metadata()?.node_type != NodeType::Socket { + return Err(AxError::NotASocket); + } + f(loc + .user_data() + .get::() + .ok_or(AxError::ConnectionRefused)? + .as_ref()) + } + } +} +fn with_slot_or_insert( + addr: &UnixSocketAddr, + f: impl FnOnce(&BindSlot) -> AxResult, +) -> AxResult { + match addr { + UnixSocketAddr::Unnamed => Err(AxError::InvalidInput), + UnixSocketAddr::Abstract(name) => { + let mut binds = ABSTRACT_BINDS.lock(); + f(binds.entry(name.clone()).or_default()) + } + UnixSocketAddr::Path(path) => { + let loc = OpenOptions::new() + .write(true) + .create(true) + .node_type(NodeType::Socket) + .open(&*FS_CONTEXT.lock(), path.as_ref())? + .into_location(); + if loc.metadata()?.node_type != NodeType::Socket { + return Err(AxError::NotASocket); + } + f(loc + .user_data() + .get_or_insert_with(|| BindSlot::default()) + .as_ref()) + } + } +} + +pub struct UnixSocket { + transport: Transport, + local_addr: Mutex, + remote_addr: Mutex, +} +impl UnixSocket { + pub fn new(transport: impl Into) -> Self { + Self { + transport: transport.into(), + local_addr: Mutex::new(UnixSocketAddr::Unnamed), + remote_addr: Mutex::new(UnixSocketAddr::Unnamed), + } + } +} +impl Configurable for UnixSocket { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + self.transport.get_option_inner(opt) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + self.transport.set_option_inner(opt) + } +} +impl SocketOps for UnixSocket { + fn bind(&self, local_addr: SocketAddrEx) -> AxResult { + let local_addr = local_addr.into_unix()?; + let mut guard = self.local_addr.lock(); + if matches!(&*guard, UnixSocketAddr::Unnamed) { + with_slot_or_insert(&local_addr, |slot| self.transport.bind(slot, &local_addr))?; + *guard = local_addr; + } else { + return Err(AxError::InvalidInput); + } + Ok(()) + } + + fn connect(&self, remote_addr: SocketAddrEx) -> AxResult { + let remote_addr = remote_addr.into_unix()?; + let local_addr = self.local_addr.lock().clone(); + let mut guard = self.remote_addr.lock(); + if matches!(&*guard, UnixSocketAddr::Unnamed) { + with_slot(&remote_addr, |slot| { + self.transport.connect(slot, &local_addr) + })?; + *guard = remote_addr; + } else { + return Err(AxError::InvalidInput); + } + Ok(()) + } + + fn listen(&self) -> AxResult { + Ok(()) + } + + fn accept(&self) -> AxResult { + let (transport, peer_addr) = block_on(interruptible(self.transport.accept()))??; + Ok(Socket::Unix(Self { + transport, + local_addr: Mutex::new(self.local_addr.lock().clone()), + remote_addr: Mutex::new(peer_addr), + })) + } + + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult { + self.transport.send(src, options) + } + + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions<'_>) -> AxResult { + self.transport.recv(dst, options) + } + + fn local_addr(&self) -> AxResult { + Ok(SocketAddrEx::Unix(self.local_addr.lock().clone())) + } + + fn peer_addr(&self) -> AxResult { + Ok(SocketAddrEx::Unix(self.remote_addr.lock().clone())) + } + + fn shutdown(&self, how: Shutdown) -> AxResult { + self.transport.shutdown(how) + } +} + +impl Pollable for UnixSocket { + fn poll(&self) -> IoEvents { + self.transport.poll() + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + self.transport.register(context, events); + } +} diff --git a/modules/axnet/src/unix/dgram.rs b/modules/axnet/src/unix/dgram.rs new file mode 100644 index 0000000000..af3973c95c --- /dev/null +++ b/modules/axnet/src/unix/dgram.rs @@ -0,0 +1,278 @@ +use alloc::{boxed::Box, sync::Arc, vec::Vec}; +use core::task::Context; + +use async_channel::TryRecvError; +use async_trait::async_trait; +use axerrno::{AxError, AxResult}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, PollSet, Pollable}; +use axsync::Mutex; +use spin::RwLock; + +use crate::{ + CMsgData, RecvFlags, RecvOptions, SendOptions, SocketAddrEx, + general::GeneralOptions, + options::{Configurable, GetSocketOption, SetSocketOption, UnixCredentials}, + unix::{Transport, TransportOps, UnixSocketAddr, with_slot}, +}; + +struct Packet { + data: Vec, + cmsg: Vec, + sender: UnixSocketAddr, +} + +struct Channel { + data_tx: async_channel::Sender, + poll_update: Arc, +} + +pub struct Bind { + data_tx: async_channel::Sender, + poll_update: Arc, +} +impl Bind { + fn connect(&self) -> Channel { + let tx = self.data_tx.clone(); + Channel { + data_tx: tx, + poll_update: self.poll_update.clone(), + } + } +} + +pub struct DgramTransport { + data_rx: Mutex, Arc)>>, + connected: RwLock>, + local_addr: RwLock, + poll_state: Arc, + general: GeneralOptions, + pid: u32, +} +impl DgramTransport { + pub fn new(pid: u32) -> Self { + DgramTransport { + data_rx: Mutex::new(None), + connected: RwLock::new(None), + local_addr: RwLock::new(UnixSocketAddr::Unnamed), + poll_state: Arc::default(), + general: GeneralOptions::default(), + pid, + } + } + + fn new_connected( + data_rx: (async_channel::Receiver, Arc), + connected: Channel, + pid: u32, + ) -> Self { + DgramTransport { + data_rx: Mutex::new(Some(data_rx)), + connected: RwLock::new(Some(connected)), + local_addr: RwLock::new(UnixSocketAddr::Unnamed), + poll_state: Arc::default(), + general: GeneralOptions::default(), + pid, + } + } + + pub fn new_pair(pid: u32) -> (Self, Self) { + let (tx1, rx1) = async_channel::unbounded(); + let (tx2, rx2) = async_channel::unbounded(); + let poll1 = Arc::new(PollSet::new()); + let poll2 = Arc::new(PollSet::new()); + let transport1 = DgramTransport::new_connected( + (rx1, poll1.clone()), + Channel { + data_tx: tx2, + poll_update: poll2.clone(), + }, + pid, + ); + let transport2 = DgramTransport::new_connected( + (rx2, poll2.clone()), + Channel { + data_tx: tx1, + poll_update: poll1.clone(), + }, + pid, + ); + (transport1, transport2) + } +} + +impl Configurable for DgramTransport { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + use GetSocketOption as O; + + if self.general.get_option_inner(opt)? { + return Ok(true); + } + + match opt { + O::PassCredentials(_) => {} + O::PeerCredentials(cred) => { + // Datagram sockets are stateless and do not have a peer, so we + // return the credentials of the process that created the + // socket. + **cred = UnixCredentials::new(self.pid); + } + _ => return Ok(false), + } + Ok(true) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + use SetSocketOption as O; + + if self.general.set_option_inner(opt)? { + return Ok(true); + } + + match opt { + O::PassCredentials(_) => {} + _ => return Ok(false), + } + Ok(true) + } +} +#[async_trait] +impl TransportOps for DgramTransport { + fn bind(&self, slot: &super::BindSlot, local_addr: &UnixSocketAddr) -> AxResult { + let mut slot = slot.dgram.lock(); + if slot.is_some() { + return Err(AxError::AddrInUse); + } + let mut guard = self.data_rx.lock(); + if guard.is_some() { + return Err(AxError::InvalidInput); + } + let (tx, rx) = async_channel::unbounded(); + let poll_update = Arc::new(PollSet::new()); + *slot = Some(Bind { + data_tx: tx, + poll_update: poll_update.clone(), + }); + *guard = Some((rx, poll_update)); + self.local_addr.write().clone_from(local_addr); + self.poll_state.wake(); + Ok(()) + } + + fn connect(&self, slot: &super::BindSlot, _local_addr: &UnixSocketAddr) -> AxResult { + let mut guard = self.connected.write(); + if guard.is_some() { + return Err(AxError::AlreadyConnected); + } + *guard = Some( + slot.dgram + .lock() + .as_ref() + .ok_or(AxError::NotConnected)? + .connect(), + ); + self.poll_state.wake(); + Ok(()) + } + + async fn accept(&self) -> AxResult<(Transport, UnixSocketAddr)> { + Err(AxError::InvalidInput) + } + + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult { + let mut message = Vec::new(); + src.read_to_end(&mut message)?; + let len = message.len(); + let packet = Packet { + data: message, + cmsg: options.cmsg, + sender: self.local_addr.read().clone(), + }; + + let connected = self.connected.read(); + if let Some(addr) = options.to { + let addr = addr.into_unix()?; + with_slot(&addr, |slot| { + if let Some(bind) = slot.dgram.lock().as_ref() { + bind.data_tx + .try_send(packet) + .map_err(|_| AxError::BrokenPipe)?; + bind.poll_update.wake(); + Ok(()) + } else { + Err(AxError::NotConnected) + } + })?; + } else if let Some(chan) = connected.as_ref() { + chan.data_tx + .try_send(packet) + .map_err(|_| AxError::BrokenPipe)?; + chan.poll_update.wake(); + } else { + return Err(AxError::NotConnected); + } + Ok(len) + } + + fn recv(&self, dst: &mut impl BufMut, mut options: RecvOptions) -> AxResult { + self.general.recv_poller(self, move || { + let mut guard = self.data_rx.lock(); + let Some((rx, _)) = guard.as_mut() else { + return Err(AxError::NotConnected); + }; + + let Packet { data, cmsg, sender } = match rx.try_recv() { + Ok(packet) => packet, + Err(TryRecvError::Empty) => { + return Err(AxError::WouldBlock); + } + Err(TryRecvError::Closed) => { + return Ok(0); + } + }; + let count = dst.write(&data)?; + if count < data.len() { + warn!("UDP message truncated: {} -> {} bytes", data.len(), count); + } + + if let Some(from) = options.from.as_mut() { + **from = SocketAddrEx::Unix(sender); + } + if let Some(dst) = options.cmsg.as_mut() { + dst.extend(cmsg); + } + + Ok(if options.flags.contains(RecvFlags::TRUNCATE) { + data.len() + } else { + count + }) + }) + } +} + +impl Pollable for DgramTransport { + fn poll(&self) -> IoEvents { + let mut events = IoEvents::OUT; + if let Some((rx, _)) = self.data_rx.lock().as_ref() { + events.set(IoEvents::IN, !rx.is_empty()); + } + events + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + if let Some((_, poll)) = self.data_rx.lock().as_ref() { + if events.contains(IoEvents::IN) { + poll.register(context.waker()); + } + } + } +} + +impl Drop for DgramTransport { + fn drop(&mut self) { + if let Some(chan) = self.connected.write().take() { + chan.poll_update.wake(); + } + } +} diff --git a/modules/axnet/src/unix/stream.rs b/modules/axnet/src/unix/stream.rs new file mode 100644 index 0000000000..e00dc11569 --- /dev/null +++ b/modules/axnet/src/unix/stream.rs @@ -0,0 +1,336 @@ +use alloc::{boxed::Box, sync::Arc}; +use core::{ + sync::atomic::{AtomicBool, Ordering}, + task::Context, +}; + +use async_trait::async_trait; +use axerrno::{AxError, AxResult}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, PollSet, Pollable}; +use axsync::Mutex; +use ringbuf::{ + HeapCons, HeapProd, HeapRb, + traits::{Consumer, Observer, Producer, Split}, +}; + +use crate::{ + RecvOptions, SendOptions, Shutdown, + general::GeneralOptions, + options::{Configurable, GetSocketOption, SetSocketOption, UnixCredentials}, + unix::{Transport, TransportOps, UnixSocketAddr}, +}; + +const BUF_SIZE: usize = 64 * 1024; + +fn new_uni_channel() -> (HeapProd, HeapCons) { + let rb = HeapRb::new(BUF_SIZE); + rb.split() +} +fn new_channels(pid: u32) -> (Channel, Channel) { + let (client_tx, server_rx) = new_uni_channel(); + let (server_tx, client_rx) = new_uni_channel(); + let poll_update = Arc::new(PollSet::new()); + ( + Channel { + tx: client_tx, + rx: client_rx, + poll_update: poll_update.clone(), + peer_pid: pid, + }, + Channel { + tx: server_tx, + rx: server_rx, + poll_update, + peer_pid: pid, + }, + ) +} + +struct Channel { + tx: HeapProd, + rx: HeapCons, + // TODO: granularity + poll_update: Arc, + peer_pid: u32, +} + +pub struct Bind { + /// New connections are sent to this channel. + conn_tx: async_channel::Sender, + poll_new_conn: Arc, + pid: u32, +} +impl Bind { + fn connect(&self, local_addr: UnixSocketAddr, pid: u32) -> AxResult { + let (mut client_chan, mut server_chan) = new_channels(0); + client_chan.peer_pid = self.pid; + server_chan.peer_pid = pid; + self.conn_tx + .try_send(ConnRequest { + channel: server_chan, + addr: local_addr, + pid, + }) + .map_err(|_| AxError::ConnectionRefused)?; + self.poll_new_conn.wake(); + Ok(client_chan) + } +} + +struct ConnRequest { + channel: Channel, + addr: UnixSocketAddr, + pid: u32, +} + +pub struct StreamTransport { + channel: Mutex>, + conn_rx: Mutex, Arc)>>, + poll_state: PollSet, + general: GeneralOptions, + pid: u32, + rx_closed: AtomicBool, + tx_closed: AtomicBool, +} +impl StreamTransport { + pub fn new(pid: u32) -> Self { + StreamTransport::new_channel(None, pid) + } + + fn new_channel(channel: Option, pid: u32) -> Self { + StreamTransport { + channel: Mutex::new(channel), + conn_rx: Mutex::new(None), + poll_state: PollSet::new(), + general: GeneralOptions::default(), + pid, + rx_closed: AtomicBool::new(false), + tx_closed: AtomicBool::new(false), + } + } + + pub fn new_pair(pid: u32) -> (Self, Self) { + let (chan1, chan2) = new_channels(pid); + let transport1 = StreamTransport::new_channel(Some(chan1), pid); + let transport2 = StreamTransport::new_channel(Some(chan2), pid); + (transport1, transport2) + } +} + +impl Configurable for StreamTransport { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + use GetSocketOption as O; + + if self.general.get_option_inner(opt)? { + return Ok(true); + } + + match opt { + O::SendBuffer(size) => { + **size = BUF_SIZE; + } + O::PassCredentials(_) => {} + O::PeerCredentials(cred) => { + let peer_pid = self + .channel + .lock() + .as_ref() + .map_or(self.pid, |chan| chan.peer_pid); + **cred = UnixCredentials::new(peer_pid); + } + _ => return Ok(false), + } + Ok(true) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + use SetSocketOption as O; + + if self.general.set_option_inner(opt)? { + return Ok(true); + } + + match opt { + O::PassCredentials(_) => {} + _ => return Ok(false), + } + Ok(true) + } +} +#[async_trait] +impl TransportOps for StreamTransport { + fn bind(&self, slot: &super::BindSlot, _local_addr: &UnixSocketAddr) -> AxResult<()> { + let mut slot = slot.stream.lock(); + if slot.is_some() { + return Err(AxError::AddrInUse); + } + let mut guard = self.conn_rx.lock(); + if guard.is_some() { + return Err(AxError::InvalidInput); + } + let (tx, rx) = async_channel::unbounded(); + let poll = Arc::new(PollSet::new()); + *slot = Some(Bind { + conn_tx: tx, + poll_new_conn: poll.clone(), + pid: self.pid, + }); + *guard = Some((rx, poll)); + self.poll_state.wake(); + Ok(()) + } + + fn connect(&self, slot: &super::BindSlot, local_addr: &UnixSocketAddr) -> AxResult<()> { + let mut guard = self.channel.lock(); + if guard.is_some() { + return Err(AxError::AlreadyConnected); + } + *guard = Some( + slot.stream + .lock() + .as_ref() + .ok_or(AxError::NotConnected)? + .connect(local_addr.clone(), self.pid)?, + ); + self.poll_state.wake(); + Ok(()) + } + + async fn accept(&self) -> AxResult<(Transport, UnixSocketAddr)> { + let mut guard = self.conn_rx.lock(); + let Some((rx, _)) = guard.as_mut() else { + return Err(AxError::NotConnected); + }; + let ConnRequest { + channel, + addr: peer_addr, + pid, + } = rx.recv().await.map_err(|_| AxError::ConnectionReset)?; + Ok(( + Transport::Stream(StreamTransport::new_channel(Some(channel), pid)), + peer_addr, + )) + } + + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult { + if options.to.is_some() { + return Err(AxError::InvalidInput); + } + let size = src.remaining(); + let mut total = 0; + let non_blocking = self.general.nonblocking(); + self.general.send_poller(self, || { + let mut guard = self.channel.lock(); + let Some(chan) = guard.as_mut() else { + return Err(AxError::NotConnected); + }; + if !chan.tx.read_is_held() { + return Err(AxError::BrokenPipe); + } + + let count = { + let (left, right) = chan.tx.vacant_slices_mut(); + let mut count = src.read(unsafe { left.assume_init_mut() })?; + if count >= left.len() { + count += src.read(unsafe { right.assume_init_mut() })?; + } + unsafe { chan.tx.advance_write_index(count) }; + count + }; + total += count; + if count > 0 { + chan.poll_update.wake(); + } + + if count == size || non_blocking { + Ok(total) + } else { + Err(AxError::WouldBlock) + } + }) + } + + fn recv(&self, dst: &mut impl BufMut, _options: RecvOptions) -> AxResult { + self.general.recv_poller(self, || { + let mut guard = self.channel.lock(); + let Some(chan) = guard.as_mut() else { + return Err(AxError::NotConnected); + }; + + let count = { + let (left, right) = chan.rx.as_slices(); + let mut count = dst.write(left)?; + if count >= left.len() { + count += dst.write(right)?; + } + unsafe { chan.rx.advance_read_index(count) }; + count + }; + if count > 0 { + chan.poll_update.wake(); + Ok(count) + } else { + Err(AxError::WouldBlock) + } + }) + } + + fn shutdown(&self, how: Shutdown) -> AxResult<()> { + if how.has_read() { + self.rx_closed.store(true, Ordering::Release); + self.poll_state.wake(); + } + if how.has_write() { + self.tx_closed.store(true, Ordering::Release); + self.poll_state.wake(); + } + if self.rx_closed.load(Ordering::Acquire) && self.tx_closed.load(Ordering::Acquire) { + if let Some(chan) = self.channel.lock().take() { + chan.poll_update.wake(); + } + } + Ok(()) + } +} + +impl Pollable for StreamTransport { + fn poll(&self) -> IoEvents { + let mut events = IoEvents::empty(); + if let Some(chan) = self.channel.lock().as_ref() { + events.set( + IoEvents::IN, + !self.rx_closed.load(Ordering::Acquire) && chan.rx.occupied_len() > 0, + ); + events.set( + IoEvents::OUT, + !self.tx_closed.load(Ordering::Acquire) && chan.tx.vacant_len() > 0, + ); + } else if let Some((conn_tx, _)) = self.conn_rx.lock().as_ref() { + events.set(IoEvents::IN, conn_tx.len() > 0); + } + events.set(IoEvents::RDHUP, self.rx_closed.load(Ordering::Acquire)); + events + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + if let Some(chan) = self.channel.lock().as_ref() { + if events.intersects(IoEvents::IN | IoEvents::OUT) { + chan.poll_update.register(context.waker()); + } + } else if let Some((_, poll_new_conn)) = self.conn_rx.lock().as_ref() { + if events.contains(IoEvents::IN) { + poll_new_conn.register(context.waker()); + } + } + self.poll_state.register(context.waker()); + } +} + +impl Drop for StreamTransport { + fn drop(&mut self) { + if let Some(chan) = self.channel.lock().as_ref() { + chan.poll_update.wake(); + } + } +} diff --git a/modules/axnet/src/vsock.rs b/modules/axnet/src/vsock.rs new file mode 100644 index 0000000000..1048d4d494 --- /dev/null +++ b/modules/axnet/src/vsock.rs @@ -0,0 +1,134 @@ +// pub(crate) mod dgram; todo + +pub(crate) mod connection_manager; +pub(crate) mod stream; + +use core::task::Context; + +pub use axdriver::prelude::{VsockAddr, VsockConnId}; +use axerrno::{AxError, AxResult}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, Pollable}; +use enum_dispatch::enum_dispatch; + +pub use self::stream::VsockStreamTransport; +use crate::{ + RecvOptions, SendOptions, Shutdown, Socket, SocketAddrEx, SocketOps, + options::{Configurable, GetSocketOption, SetSocketOption}, +}; + +/// Abstract transport trait for Unix sockets. +#[enum_dispatch] +pub trait VsockTransportOps: Configurable + Pollable + Send + Sync { + fn bind(&self, local_addr: VsockAddr) -> AxResult; + fn listen(&self) -> AxResult; + fn connect(&self, peer_addr: VsockAddr) -> AxResult; + fn accept(&self) -> AxResult<(VsockTransport, VsockAddr)>; + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult; + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions<'_>) -> AxResult; + fn shutdown(&self, _how: Shutdown) -> AxResult; + fn local_addr(&self) -> AxResult>; + fn peer_addr(&self) -> AxResult>; +} + +#[enum_dispatch(Configurable, VsockTransportOps)] +pub enum VsockTransport { + Stream(VsockStreamTransport), + // Dgram(VsockDgramVsockTransport), +} + +impl Pollable for VsockTransport { + fn poll(&self) -> IoEvents { + match self { + VsockTransport::Stream(stream) => stream.poll(), + // VsockTransport::Dgram(dgram) => dgram.poll(), + } + } + + fn register(&self, context: &mut core::task::Context<'_>, events: IoEvents) { + match self { + VsockTransport::Stream(stream) => stream.register(context, events), + // VsockTransport::Dgram(dgram) => dgram.register(context, events), + } + } +} + +/// A network socket using the vsock protocol. +pub struct VsockSocket { + transport: VsockTransport, +} + +impl VsockSocket { + pub fn new(transport: impl Into) -> Self { + Self { + transport: transport.into(), + } + } +} + +impl Configurable for VsockSocket { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + self.transport.get_option_inner(opt) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + self.transport.set_option_inner(opt) + } +} + +impl SocketOps for VsockSocket { + fn bind(&self, local_addr: SocketAddrEx) -> AxResult { + let local_addr = local_addr.into_vsock()?; + self.transport.bind(local_addr) + } + + fn connect(&self, remote_addr: SocketAddrEx) -> AxResult { + let remote_addr = remote_addr.into_vsock()?; + self.transport.connect(remote_addr) + } + + fn listen(&self) -> AxResult { + self.transport.listen() + } + + fn accept(&self) -> AxResult { + self.transport.accept().map(|(transport, _addr)| { + let socket = VsockSocket::new(transport); + Socket::Vsock(socket) + }) + } + + fn send(&self, src: &mut impl Buf, options: SendOptions) -> AxResult { + self.transport.send(src, options) + } + + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions<'_>) -> AxResult { + self.transport.recv(dst, options) + } + + fn local_addr(&self) -> AxResult { + Ok(SocketAddrEx::Vsock( + self.transport.local_addr()?.ok_or(AxError::NotFound)?, + )) + } + + fn peer_addr(&self) -> AxResult { + Ok(SocketAddrEx::Vsock( + self.transport.peer_addr()?.ok_or(AxError::NotFound)?, + )) + } + + fn shutdown(&self, how: Shutdown) -> AxResult { + self.transport.shutdown(how) + } +} + +impl Pollable for VsockSocket { + fn poll(&self) -> IoEvents { + self.transport.poll() + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + self.transport.register(context, events); + } +} diff --git a/modules/axnet/src/vsock/connection_manager.rs b/modules/axnet/src/vsock/connection_manager.rs new file mode 100644 index 0000000000..78e8283cc9 --- /dev/null +++ b/modules/axnet/src/vsock/connection_manager.rs @@ -0,0 +1,517 @@ +use alloc::{collections::BTreeMap, sync::Arc}; + +use axerrno::{AxError, AxResult, ax_bail}; +use axpoll::PollSet; +use axsync::Mutex; +use ringbuf::{HeapCons, HeapProd, HeapRb, traits::*}; + +use super::{VsockAddr, VsockConnId}; + +pub const VSOCK_RX_BUFFER_SIZE: usize = 64 * 1024; // 64KB receive buffer +const VSOCK_ACCEPT_QUEUE_SIZE: usize = 128; // accept queue size + +/// connection states +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + Idle, + Listening, + Connecting, + Connected, + Closed, +} + +/// Connection +pub struct Connection { + state: ConnectionState, + local_addr: VsockAddr, + peer_addr: Option, + + /// recv buffer read from driver + rx_producer: HeapProd, + rx_consumer: HeapCons, + + /// Waker lists + rx_wakers: PollSet, + connect_wakers: PollSet, + + /// closed flags + rx_closed: bool, + tx_closed: bool, + + /// statistics + rx_bytes: usize, // received bytes count + tx_bytes: usize, // sent bytes count + dropped_bytes: usize, // dropped bytes count +} + +impl Connection { + fn new(local_addr: VsockAddr, peer_addr: Option, state: ConnectionState) -> Self { + let rb = HeapRb::::new(VSOCK_RX_BUFFER_SIZE); + let (rx_producer, rx_consumer) = rb.split(); + Self { + state, + local_addr, + peer_addr, + rx_producer, + rx_consumer, + rx_wakers: PollSet::new(), + connect_wakers: PollSet::new(), + rx_closed: false, + tx_closed: false, + rx_bytes: 0, + tx_bytes: 0, + dropped_bytes: 0, + } + } + + /// Register a waker for transmit events + pub fn register_accept_poll(&mut self, context: &mut core::task::Context<'_>) { + // found listen queue + let manager = VSOCK_CONN_MANAGER.lock(); + let queue = manager + .get_listen_queue(self.local_addr.port) + .expect("listen queue not found"); + drop(manager); + queue.lock().register_poll(context); + } + + /// Register a waker for receive Events + pub fn register_rx_poll(&mut self, context: &mut core::task::Context<'_>) { + self.rx_wakers.register(context.waker()); + } + + /// Register a waker for connect Events + pub fn register_connect_poll(&mut self, _context: &mut core::task::Context<'_>) { + self.connect_wakers.register(_context.waker()); + } + + /// Get the free space in the receive buffer + #[inline] + pub fn rx_buffer_free(&self) -> usize { + self.rx_producer.vacant_len() + } + + /// Get the used space in the receive buffer + #[inline] + pub fn rx_buffer_used(&self) -> usize { + self.rx_consumer.occupied_len() + } + + /// push data into the receive buffer + pub fn push_rx_data(&mut self, data: &[u8]) -> usize { + let available = self.rx_buffer_free(); + let to_write = data.len().min(available); + + if to_write > 0 { + let written = self.rx_producer.push_slice(&data[..to_write]); + self.rx_bytes += written; + + if written < data.len() { + let dropped = data.len() - written; + self.dropped_bytes += dropped; + info!( + "Vsock connection {:?} rx buffer full, dropped {} bytes", + (self.local_addr, self.peer_addr), + dropped + ); + } + written + } else { + self.dropped_bytes += data.len(); + info!( + "Vsock connection {:?} rx buffer full, dropped {} bytes", + (self.local_addr, self.peer_addr), + data.len() + ); + 0 + } + } + + #[inline] + pub fn rx_slices(&self) -> (&[u8], &[u8]) { + self.rx_consumer.as_slices() + } + + #[inline] + pub fn advance_rx_read(&mut self, count: usize) { + unsafe { + self.rx_consumer.advance_read_index(count); + } + } + + #[inline] + pub fn rx_iter(&self) -> impl Iterator { + self.rx_consumer.iter() + } + + #[inline] + pub fn add_tx_bytes(&mut self, count: usize) { + self.tx_bytes += count; + } + + #[inline] + pub fn wake_rx(&mut self) { + self.rx_wakers.wake(); + } + + #[inline] + pub fn wake_connect(&mut self) { + self.connect_wakers.wake(); + } + + #[inline] + pub fn local_addr(&self) -> VsockAddr { + self.local_addr + } + + #[inline] + pub fn peer_addr(&self) -> Option { + self.peer_addr + } + + #[inline] + pub fn set_state(&mut self, state: ConnectionState) { + self.state = state; + } + + #[inline] + pub fn state(&self) -> ConnectionState { + self.state + } + + #[inline] + pub fn rx_closed(&self) -> bool { + self.rx_closed + } + + #[inline] + pub fn tx_closed(&self) -> bool { + self.tx_closed + } + + #[inline] + pub fn set_rx_closed(&mut self, closed: bool) { + self.rx_closed = closed; + } + + #[inline] + pub fn set_tx_closed(&mut self, closed: bool) { + self.tx_closed = closed; + } +} + +/// A fixed-size accept queue +pub struct AcceptQueue { + producer: ringbuf::HeapProd, + consumer: ringbuf::HeapCons, +} + +impl AcceptQueue { + pub fn new() -> Self { + let rb = HeapRb::::new(VSOCK_ACCEPT_QUEUE_SIZE); + let (producer, consumer) = rb.split(); + Self { producer, consumer } + } + + pub fn is_empty(&self) -> bool { + self.consumer.is_empty() + } + + pub fn push(&mut self, conn_id: VsockConnId) -> AxResult<()> { + match self.producer.try_push(conn_id) { + Ok(_) => Ok(()), + Err(_) => ax_bail!(ResourceBusy, "accept queue full"), + } + } + + pub fn pop(&mut self) -> Option { + self.consumer.try_pop() + } +} + +/// listen queue +pub struct ListenQueue { + pub accept_queue: AcceptQueue, + pub wakers: PollSet, + pub local_addr: VsockAddr, +} + +impl ListenQueue { + pub fn new(local_addr: VsockAddr) -> Self { + Self { + accept_queue: AcceptQueue::new(), + wakers: PollSet::new(), + local_addr, + } + } + + pub fn wake(&mut self) { + self.wakers.wake(); + } + + pub fn register_poll(&mut self, context: &mut core::task::Context<'_>) { + self.wakers.register(context.waker()); + } +} + +/// Global connection manager +pub struct VsockConnectionManager { + connections: BTreeMap>>, + listen_queues: BTreeMap>>, + next_ephemeral_port: u32, +} + +impl VsockConnectionManager { + const EPHEMERAL_PORT_END: u32 = 0xffff; + const EPHEMERAL_PORT_START: u32 = 0xc000; + + pub const fn new() -> Self { + Self { + connections: BTreeMap::new(), + listen_queues: BTreeMap::new(), + next_ephemeral_port: Self::EPHEMERAL_PORT_START, + } + } + + /// Get listen queue from specified port + pub fn get_listen_queue(&self, port: u32) -> Option>> { + self.listen_queues.get(&port).cloned() + } + + /// allocate an ephemeral port + pub fn allocate_port(&mut self) -> AxResult { + let start = self.next_ephemeral_port; + loop { + let port = self.next_ephemeral_port; + self.next_ephemeral_port = if port >= Self::EPHEMERAL_PORT_END { + Self::EPHEMERAL_PORT_START + } else { + port + 1 + }; + + // check if port is in use by listen queue + if !self.listen_queues.contains_key(&port) { + // check if port is in use by existing connections + let port_in_use = self.connections.keys().any(|id| id.local_port == port); + if !port_in_use { + return Ok(port); + } + } + + if self.next_ephemeral_port == start { + ax_bail!(AddrInUse, "no available ports"); + } + } + } + + /// create a listen queue + pub fn listen(&mut self, local_addr: VsockAddr) -> AxResult<()> { + if self.listen_queues.contains_key(&local_addr.port) { + ax_bail!(AddrInUse, "port already in use"); + } + + let queue = Arc::new(Mutex::new(ListenQueue::new(local_addr))); + self.listen_queues.insert(local_addr.port, queue); + Ok(()) + } + + /// stop listening + pub fn unlisten(&mut self, port: u32) { + self.listen_queues.remove(&port); + debug!("Vsock unlisten on port {}", port); + } + + /// check if port accept + pub fn can_accept(&self, port: u32) -> bool { + self.listen_queues + .get(&port) + .map(|q| !q.lock().accept_queue.is_empty()) + .unwrap_or(false) + } + + /// accept a connection + pub fn accept(&mut self, port: u32) -> AxResult<(VsockConnId, VsockAddr)> { + let queue = self.listen_queues.get(&port).ok_or(AxError::InvalidInput)?; + + let conn_id = queue.lock().accept_queue.pop().ok_or(AxError::WouldBlock)?; + + let conn = self.connections.get(&conn_id).ok_or(AxError::NotFound)?; + + let peer_addr = conn.lock().peer_addr.ok_or(AxError::NotFound)?; + + debug!("Accepted connection: {:?} from {:?}", conn_id, peer_addr); + Ok((conn_id, peer_addr)) + } + + /// create a new connection + pub fn create_connection( + &mut self, + conn_id: VsockConnId, + local_addr: VsockAddr, + peer_addr: Option, + state: ConnectionState, + ) -> Arc> { + let conn = Connection::new(local_addr, peer_addr, state); + let conn = Arc::new(Mutex::new(conn)); + if self.connections.contains_key(&conn_id) { + info!("Connection {:?} already exists, overwriting", conn_id); + } else { + crate::device::start_vsock_poll(); + } + self.connections.insert(conn_id, conn.clone()); + debug!( + "Created connection {:?}: local={:?}, peer={:?}", + conn_id, local_addr, peer_addr + ); + conn + } + + /// get a connection by id + pub fn get_connection(&self, conn_id: VsockConnId) -> Option>> { + self.connections.get(&conn_id).cloned() + } + + /// remove a connection + pub fn remove_connection(&mut self, conn_id: VsockConnId) { + if let Some(conn) = self.connections.remove(&conn_id) { + let conn = conn.lock(); + crate::device::stop_vsock_poll(); + debug!( + "Removed connection {:?}: rx={} bytes, tx={} bytes, dropped={} bytes", + conn_id, conn.rx_bytes, conn.tx_bytes, conn.dropped_bytes + ); + } + } + + /// handle a new connection request (by driver event) + pub fn on_connection_request(&mut self, conn_id: VsockConnId) -> AxResult<()> { + let queue = self + .listen_queues + .get(&conn_id.local_port) + .ok_or(AxError::NotFound)? + .clone(); + + let local_addr = queue.lock().local_addr; + + // check if connection already exists + if self.connections.contains_key(&conn_id) { + warn!("Connection {:?} already exists, ignoring request", conn_id); + return Ok(()); + } + + // create new connection + self.create_connection( + conn_id, + local_addr, + Some(conn_id.peer_addr), + ConnectionState::Connected, + ); + + // 加入 accept 队列 + let mut queue_guard = queue.lock(); + if let Err(_) = queue_guard.accept_queue.push(conn_id) { + info!( + "Accept queue full for port {}, dropping connection from {:?}", + conn_id.local_port, conn_id.peer_addr + ); + // full -- remove the connection + drop(queue_guard); + self.remove_connection(conn_id); + return Err(AxError::ResourceBusy); + } + + queue_guard.wake(); + drop(queue_guard); + + trace!( + "New connection request from {:?} on port {}", + conn_id.peer_addr, conn_id.local_port + ); + Ok(()) + } + + /// handle data received (by driver event) + pub fn on_data_received(&mut self, conn_id: VsockConnId, data: &[u8]) -> AxResult<()> { + let conn = self + .connections + .get(&conn_id) + .ok_or(AxError::NotFound)? + .clone(); + + let mut conn_guard = conn.lock(); + let written = conn_guard.push_rx_data(data); + if written > 0 { + conn_guard.wake_rx(); + } + + trace!( + "Received {} bytes for connection {:?} (written={}, buffer_used={}/{})", + data.len(), + conn_id, + written, + conn_guard.rx_buffer_used(), + VSOCK_RX_BUFFER_SIZE + ); + Ok(()) + } + + /// handle disconnection (by driver event) + pub fn on_disconnected(&mut self, conn_id: VsockConnId) -> AxResult<()> { + if let Some(conn) = self.connections.get(&conn_id) { + let mut conn_guard = conn.lock(); + conn_guard.state = ConnectionState::Closed; + conn_guard.rx_closed = true; + conn_guard.tx_closed = true; + conn_guard.wake_rx(); + trace!("Connection {:?} disconnected", conn_id); + } + Ok(()) + } + + /// handle connected event (by driver event) + pub fn on_connected(&mut self, conn_id: VsockConnId) -> AxResult<()> { + if let Some(conn) = self.connections.get(&conn_id) { + let mut conn_guard = conn.lock(); + conn_guard.state = ConnectionState::Connected; + conn_guard.wake_connect(); + trace!("Connection {:?} established", conn_id); + } + Ok(()) + } + + /// statistics + #[allow(dead_code)] + pub fn get_stats(&self) -> VsockStats { + VsockStats { + total_connections: self.connections.len(), + listening_ports: self.listen_queues.len(), + total_rx_bytes: self.connections.values().map(|c| c.lock().rx_bytes).sum(), + total_tx_bytes: self.connections.values().map(|c| c.lock().tx_bytes).sum(), + total_dropped_bytes: self + .connections + .values() + .map(|c| c.lock().dropped_bytes) + .sum(), + } + } +} + +/// Vsock statistics +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub struct VsockStats { + pub total_connections: usize, + pub listening_ports: usize, + pub total_rx_bytes: usize, + pub total_tx_bytes: usize, + pub total_dropped_bytes: usize, +} + +pub static VSOCK_CONN_MANAGER: Mutex = + Mutex::new(VsockConnectionManager::new()); + +/// for debug +#[allow(dead_code)] +pub fn get_vsock_stats() -> VsockStats { + VSOCK_CONN_MANAGER.lock().get_stats() +} diff --git a/modules/axnet/src/vsock/stream.rs b/modules/axnet/src/vsock/stream.rs new file mode 100644 index 0000000000..89b68a9400 --- /dev/null +++ b/modules/axnet/src/vsock/stream.rs @@ -0,0 +1,379 @@ +use alloc::sync::Arc; +use core::task::Context; + +use axerrno::{AxError, AxResult, ax_bail, ax_err_type}; +use axio::{Buf, BufMut}; +use axpoll::{IoEvents, Pollable}; +use axsync::Mutex; + +use super::connection_manager::*; +use crate::{ + RecvFlags, RecvOptions, SendOptions, Shutdown, + general::GeneralOptions, + options::{Configurable, GetSocketOption, SetSocketOption}, + state::*, + vsock::{VsockAddr, VsockConnId, VsockTransport, VsockTransportOps}, +}; + +pub struct VsockStreamTransport { + conn_id: Mutex>, + connection: Mutex>>>, + state: StateLock, + general: GeneralOptions, +} + +impl VsockStreamTransport { + pub fn new() -> Self { + Self { + conn_id: Mutex::new(None), + connection: Mutex::new(None), + state: StateLock::new(State::Idle), + general: GeneralOptions::new(), + } + } + + fn get_connection(&self) -> AxResult>> { + self.connection.lock().clone().ok_or(AxError::NotConnected) + } +} + +impl Configurable for VsockStreamTransport { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + self.general.get_option_inner(opt) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + self.general.set_option_inner(opt) + } +} + +impl VsockTransportOps for VsockStreamTransport { + fn bind(&self, mut local_addr: VsockAddr) -> AxResult<()> { + self.state + .lock(State::Idle) + .map_err(|_| ax_err_type!(InvalidInput, "already bound"))? + .transit(State::Idle, || { + let mut manager = VSOCK_CONN_MANAGER.lock(); + if local_addr.port == 0 { + local_addr.port = manager.allocate_port()?; + } + let conn_id = VsockConnId::listening(local_addr.port); + let conn = + manager.create_connection(conn_id, local_addr, None, ConnectionState::Idle); + + *self.conn_id.lock() = Some(conn_id); + *self.connection.lock() = Some(conn); + trace!("Vsock binding to {:?}", local_addr); + Ok(()) + })?; + Ok(()) + } + + fn listen(&self) -> AxResult<()> { + let guard = self + .state + .lock(State::Idle) + .map_err(|_| ax_err_type!(InvalidInput, "invalid state for listen"))?; + + guard.transit(State::Listening, || { + let conn = self.get_connection()?; + let local_addr = conn.lock().local_addr(); + + // register in the global listen table + VSOCK_CONN_MANAGER.lock().listen(local_addr)?; + crate::device::vsock_listen(local_addr)?; + // set state + conn.lock().set_state(ConnectionState::Listening); + trace!("Vsock listening on {:?}", local_addr); + Ok(()) + }) + } + + fn accept(&self) -> AxResult<(VsockTransport, VsockAddr)> { + if self.state.get() != State::Listening { + ax_bail!(InvalidInput, "not listening"); + } + + let conn = self.get_connection()?; + let local_port = conn.lock().local_addr().port; + + // wait for connection + self.general.recv_poller(self, || { + let mut manager = VSOCK_CONN_MANAGER.lock(); + + if !manager.can_accept(local_port) { + return Err(AxError::WouldBlock); + } + + let (conn_id, peer_addr) = manager.accept(local_port)?; + let conn = manager.get_connection(conn_id).ok_or(AxError::NotFound)?; + + // create new VsockStreamTransport + let new_transport = VsockStreamTransport { + conn_id: Mutex::new(Some(conn_id)), + connection: Mutex::new(Some(conn)), + state: StateLock::new(State::Connected), + general: GeneralOptions::default(), + }; + + Ok((VsockTransport::Stream(new_transport), peer_addr)) + }) + } + + fn connect(&self, peer_addr: VsockAddr) -> AxResult<()> { + let guard = self.state.lock(State::Idle).map_err(|state| match state { + State::Idle => unreachable!(), + State::Listening => ax_err_type!(InvalidInput, "already listening"), + State::Connecting => ax_err_type!(InProgress), + State::Connected => ax_err_type!(AlreadyConnected), + _ => ax_err_type!(AlreadyConnected), + })?; + + guard.transit(State::Connecting, || { + let mut manager = VSOCK_CONN_MANAGER.lock(); + let existing_conn = self.connection.lock(); + + // get local address + let local_port = if let Some(conn) = existing_conn.as_ref() { + let conn_guard = conn.lock(); + match conn_guard.state() { + ConnectionState::Idle => { + // already bound but not connected, reuse the port + conn_guard.local_addr().port + } + _ => { + // should not happen due to state check above + ax_bail!(InvalidInput, "already connected or listening"); + } + } + } else { + manager.allocate_port()? + }; + drop(existing_conn); + + let local_addr = VsockAddr { + cid: crate::device::vsock_guest_cid()?, + port: local_port, + }; + + // create connection + let conn_id = VsockConnId { + peer_addr, + local_port, + }; + let conn = manager.create_connection( + conn_id, + local_addr, + Some(peer_addr), + ConnectionState::Connecting, + ); + + *self.conn_id.lock() = Some(conn_id); + *self.connection.lock() = Some(conn.clone()); + + drop(manager); + + // driver connect + crate::device::vsock_connect(conn_id)?; + debug!("Vsock connecting from {} to {:?}", local_port, peer_addr); + Ok(()) + })?; + + // wait for connection established + self.general.send_poller(self, || { + let conn = self.get_connection()?; + let state = conn.lock().state(); + match state { + ConnectionState::Connected => Ok(()), + ConnectionState::Connecting => Err(AxError::WouldBlock), + _ => Err(ax_err_type!(ConnectionRefused)), + } + }) + } + + fn send(&self, src: &mut impl Buf, _options: SendOptions) -> AxResult { + let conn = self.get_connection()?; + let conn_guard = conn.lock(); + + if conn_guard.state() != ConnectionState::Connected { + return Err(AxError::NotConnected); + } + + if conn_guard.tx_closed() { + return Err(AxError::NotConnected); + } + + let conn_id = self.conn_id.lock().ok_or(AxError::NotConnected)?; + drop(conn_guard); + + // now virtio-driver only support non-blocking send + let result = src.consume(|chunk| crate::device::vsock_send(conn_id, chunk)); + conn.lock().add_tx_bytes(result.unwrap_or(0)); + result + } + + fn recv(&self, dst: &mut impl BufMut, options: RecvOptions) -> AxResult { + let conn = self.get_connection()?; + + self.general.recv_poller(self, || { + let mut conn_guard = conn.lock(); + + if conn_guard.rx_closed() && conn_guard.rx_buffer_used() == 0 { + return Ok(0); // EOF + } + + // should allow read when connection is closed, to read remaining data + if !matches!( + conn_guard.state(), + ConnectionState::Connected | ConnectionState::Closed + ) { + return Err(AxError::NotConnected); + } + + if conn_guard.rx_buffer_used() == 0 { + return Err(AxError::WouldBlock); + } + + let count = if options.flags.contains(RecvFlags::PEEK) { + // Peek mode: not remove data from buffer + let available = conn_guard.rx_buffer_used(); + let to_read = dst.remaining_mut().min(available); + let data: alloc::vec::Vec = + conn_guard.rx_iter().take(to_read).copied().collect(); + dst.write(&data)? + } else { + // Normal mode: remove data from buffer + let (left, right) = conn_guard.rx_slices(); + let mut count = dst.write(left)?; + + if count >= left.len() && !right.is_empty() { + count += dst.write(right)?; + } + conn_guard.advance_rx_read(count); + count + }; + + if count > 0 { + trace!( + "Recv {} bytes from connection (buffer_remaining={}/{})", + count, + conn_guard.rx_buffer_used(), + VSOCK_RX_BUFFER_SIZE + ); + Ok(count) + } else { + return Err(AxError::WouldBlock); + } + }) + } + + fn shutdown(&self, how: Shutdown) -> AxResult<()> { + let conn = self.get_connection()?; + let mut conn = conn.lock(); + + if how.has_read() { + conn.set_rx_closed(true); + } + + if how.has_write() { + conn.set_tx_closed(true); + } + + if let Some(conn_id) = *self.conn_id.lock() { + if conn.state() == ConnectionState::Connected { + crate::device::vsock_disconnect(conn_id)?; + } else if conn.state() == ConnectionState::Listening { + VSOCK_CONN_MANAGER.lock().unlisten(conn_id.local_port); + } + } + conn.set_state(ConnectionState::Closed); + Ok(()) + } + + fn local_addr(&self) -> AxResult> { + Ok(self + .get_connection() + .ok() + .map(|conn| conn.lock().local_addr())) + } + + fn peer_addr(&self) -> AxResult> { + Ok(self + .get_connection() + .ok() + .and_then(|conn| conn.lock().peer_addr())) + } +} + +impl Pollable for VsockStreamTransport { + fn poll(&self) -> IoEvents { + let Ok(conn) = self.get_connection() else { + return IoEvents::empty(); + }; + + let conn = conn.lock(); + let mut events = IoEvents::empty(); + + match conn.state() { + ConnectionState::Listening => { + // if there is a pending connection, set IN + if let Some(conn_id) = *self.conn_id.lock() { + events.set( + IoEvents::IN, + VSOCK_CONN_MANAGER.lock().can_accept(conn_id.local_port), + ); + } + } + ConnectionState::Connected | ConnectionState::Closed => { + events.set(IoEvents::IN, conn.rx_buffer_used() > 0 || conn.rx_closed()); + events.set(IoEvents::OUT, !conn.tx_closed()); + } + ConnectionState::Connecting => { + // if connected, set OUT + events.set(IoEvents::OUT, conn.state() == ConnectionState::Connected); + } + _ => {} + } + events.set(IoEvents::RDHUP, conn.rx_closed()); + events + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + if let Ok(conn) = self.get_connection() { + let mut conn = conn.lock(); + match conn.state() { + ConnectionState::Listening => { + if events.contains(IoEvents::IN) { + conn.register_accept_poll(context); + } + } + ConnectionState::Connected => { + if events.contains(IoEvents::IN) { + conn.register_rx_poll(context); + } + if events.contains(IoEvents::OUT) { + warn!( + "VsockStreamTransport: OUT event on connected socket is not supported" + ); + } + } + ConnectionState::Connecting => { + if events.contains(IoEvents::OUT) { + conn.register_connect_poll(context); + } + } + _ => {} + } + } + } +} + +impl Drop for VsockStreamTransport { + fn drop(&mut self) { + let _ = self.shutdown(Shutdown::Both); + + if let Some(conn_id) = *self.conn_id.lock() { + VSOCK_CONN_MANAGER.lock().remove_connection(conn_id); + } + } +} diff --git a/modules/axnet/src/wrapper.rs b/modules/axnet/src/wrapper.rs new file mode 100644 index 0000000000..042ae363b3 --- /dev/null +++ b/modules/axnet/src/wrapper.rs @@ -0,0 +1,80 @@ +use alloc::vec; + +use axerrno::{AxError, AxResult}; +use axsync::Mutex; +use event_listener::Event; +use smoltcp::{ + iface::{SocketHandle, SocketSet}, + socket::{AnySocket, Socket}, + wire::IpAddress, +}; + +pub(crate) struct SocketSetWrapper<'a> { + pub inner: Mutex>, + pub new_socket: Event, +} + +impl<'a> SocketSetWrapper<'a> { + pub fn new() -> Self { + Self { + inner: Mutex::new(SocketSet::new(vec![])), + new_socket: Event::new(), + } + } + + pub fn add>(&self, socket: T) -> SocketHandle { + let handle = self.inner.lock().add(socket); + debug!("socket {}: created", handle); + self.new_socket.notify(1); + handle + } + + pub fn with_socket, R, F>(&self, handle: SocketHandle, f: F) -> R + where + F: FnOnce(&T) -> R, + { + let set = self.inner.lock(); + let socket = set.get(handle); + f(socket) + } + + pub fn with_socket_mut, R, F>(&self, handle: SocketHandle, f: F) -> R + where + F: FnOnce(&mut T) -> R, + { + let mut set = self.inner.lock(); + let socket = set.get_mut(handle); + f(socket) + } + + pub fn bind_check(&self, addr: IpAddress, port: u16) -> AxResult { + if port == 0 { + return Ok(()); + } + + // TODO(mivik): optimize + let mut sockets = self.inner.lock(); + for (_, socket) in sockets.iter_mut() { + match socket { + Socket::Tcp(s) => { + let local_addr = s.get_bound_endpoint(); + if local_addr.addr == Some(addr) && local_addr.port == port { + return Err(AxError::AddrInUse); + } + } + Socket::Udp(s) => { + if s.endpoint().addr == Some(addr) && s.endpoint().port == port { + return Err(AxError::AddrInUse); + } + } + _ => continue, + }; + } + Ok(()) + } + + pub fn remove(&self, handle: SocketHandle) { + self.inner.lock().remove(handle); + debug!("socket {}: destroyed", handle); + } +} diff --git a/modules/axruntime/Cargo.toml b/modules/axruntime/Cargo.toml index 01e409463f..7dca670cc4 100644 --- a/modules/axruntime/Cargo.toml +++ b/modules/axruntime/Cargo.toml @@ -27,6 +27,7 @@ input = ["dep:axdriver", "dep:axinput"] vsock = ["net", "dep:axdriver"] rtc = [] driver-dyn = ["axdriver/dyn"] +watchdog = ["dep:axwatchdog"] [dependencies] axalloc = { workspace = true, optional = true } @@ -43,6 +44,7 @@ axmm = { workspace = true, optional = true } axnet = { workspace = true, optional = true } axplat = { workspace = true } axtask = { workspace = true, optional = true } +axwatchdog = { workspace = true, optional = true } chrono.workspace = true crate_interface.workspace = true ctor_bare.workspace = true diff --git a/modules/axruntime/src/lib.rs b/modules/axruntime/src/lib.rs index 8fedf258c4..fe67dd5e74 100644 --- a/modules/axruntime/src/lib.rs +++ b/modules/axruntime/src/lib.rs @@ -227,6 +227,19 @@ pub fn rust_main(cpu_id: usize, arg: usize) -> ! { info!("Primary CPU {cpu_id} init OK."); INITED_CPUS.fetch_add(1, Ordering::Release); + #[cfg(feature = "watchdog")] + { + axtask::register_timer_callback(|_| { + let now_ns = axhal::time::monotonic_time_nanos(); + let cpu_id = axhal::percpu::this_cpu_id(); + axwatchdog::timer_tick(cpu_id); + if axwatchdog::check_softlockup(cpu_id, now_ns) != axwatchdog::CpuHealth::Healthy { + axtask::show_global_task_queue(cpu_id); + panic!("Softlockup detected on CPU {}", cpu_id); + } + }); + } + while !is_init_ok() { core::hint::spin_loop(); } diff --git a/modules/axruntime/src/mp.rs b/modules/axruntime/src/mp.rs index 99fd1ad65c..ede6d38179 100644 --- a/modules/axruntime/src/mp.rs +++ b/modules/axruntime/src/mp.rs @@ -61,6 +61,19 @@ pub fn rust_main_secondary(cpu_id: usize) -> ! { #[cfg(feature = "irq")] axhal::asm::enable_irqs(); + #[cfg(feature = "watchdog")] + { + axtask::register_timer_callback(|_| { + let now_ns = axhal::time::monotonic_time_nanos(); + let cpu_id = axhal::percpu::this_cpu_id(); + axwatchdog::timer_tick(cpu_id); + if axwatchdog::check_softlockup(cpu_id, now_ns) != axwatchdog::CpuHealth::Healthy { + axtask::show_global_task_queue(cpu_id); + panic!("Softlockup detected on CPU {}", cpu_id); + } + }); + } + #[cfg(all(feature = "tls", not(feature = "multitask")))] super::init_tls(); diff --git a/modules/axtask/Cargo.toml b/modules/axtask/Cargo.toml index 84767b44d9..1173c20497 100644 --- a/modules/axtask/Cargo.toml +++ b/modules/axtask/Cargo.toml @@ -38,12 +38,15 @@ sched-cfs = ["multitask", "preempt"] test = ["percpu?/sp-naive"] +watchdog = ["dep:axwatchdog"] + [dependencies] axconfig = { workspace = true, optional = true } axerrno.workspace = true axhal.workspace = true axpoll = { workspace = true, optional = true } axsched = { version = "0.3", optional = true } +axwatchdog = { workspace = true, optional = true } cfg-if.workspace = true cpumask = { version = "0.1", optional = true } crate_interface = { workspace = true, optional = true } diff --git a/modules/axtask/src/api.rs b/modules/axtask/src/api.rs index 33fe68fdfc..f162e885d5 100644 --- a/modules/axtask/src/api.rs +++ b/modules/axtask/src/api.rs @@ -240,6 +240,19 @@ pub fn exit(exit_code: i32) -> ! { current_run_queue::().exit_current(exit_code) } +/// Print all tasks in the global task queue of the specified CPU. +#[cfg(feature = "watchdog")] +pub fn show_global_task_queue(cpu_id: usize) { + for weaktask in crate::run_queue::get_global_task_queue(cpu_id) + .lock() + .iter() + { + if let Some(task) = weaktask.upgrade() { + warn!("cpu_id: {}, {:?}", cpu_id, task.inner()); + } + } +} + /// The idle task routine. /// /// It runs an infinite loop that keeps calling [`yield_now()`]. diff --git a/modules/axtask/src/future/time.rs b/modules/axtask/src/future/time.rs index d849b65ce8..e75b759d35 100644 --- a/modules/axtask/src/future/time.rs +++ b/modules/axtask/src/future/time.rs @@ -8,7 +8,7 @@ use core::{ use axerrno::AxError; use axhal::time::{TimeValue, wall_time}; -use futures_util::{FutureExt, future::FusedFuture, select_biased}; +use futures_util::{FutureExt, select_biased}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] struct TimerKey { @@ -94,46 +94,33 @@ fn with_current(f: impl FnOnce(&mut TimerRuntime) -> R) -> R { /// Future returned by `sleep` and `sleep_until`. #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct TimerFuture(Option); +pub struct TimerFuture(TimerKey); impl Future for TimerFuture { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let Some(key) = &self.0 else { - return Poll::Ready(()); - }; - let res = with_current(|r| r.poll(key, cx)); - if res.is_ready() { - self.get_mut().0 = None; - } - res - } -} - -impl FusedFuture for TimerFuture { - fn is_terminated(&self) -> bool { - self.0.is_none() + with_current(|r| r.poll(&self.0, cx)) } } impl Drop for TimerFuture { fn drop(&mut self) { - if let Some(key) = &self.0 { - with_current(|r| r.cancel(key)); - } + with_current(|r| r.cancel(&self.0)); } } /// Waits until `duration` has elapsed. -pub fn sleep(duration: Duration) -> TimerFuture { - sleep_until(wall_time() + duration) +pub async fn sleep(duration: Duration) { + sleep_until(wall_time() + duration).await } /// Waits until `deadline` is reached. -pub fn sleep_until(deadline: TimeValue) -> TimerFuture { +pub async fn sleep_until(deadline: TimeValue) { let key = with_current(|r| r.add(deadline)); - TimerFuture(key) + if let Some(key) = key { + TimerFuture(key).await; + } } /// Error returned by [`timeout`] and [`timeout_at`]. @@ -174,7 +161,7 @@ pub async fn timeout_at( if let Some(deadline) = deadline { select_biased! { res = f.into_future().fuse() => Ok(res), - _ = sleep_until(deadline) => Err(Elapsed(())), + _ = sleep_until(deadline).fuse() => Err(Elapsed(())), } } else { Ok(f.await) diff --git a/modules/axtask/src/run_queue.rs b/modules/axtask/src/run_queue.rs index 576d7677f0..6de9521a8d 100644 --- a/modules/axtask/src/run_queue.rs +++ b/modules/axtask/src/run_queue.rs @@ -22,6 +22,9 @@ use crate::{ task::{CurrentTask, TaskState}, }; +#[cfg(feature = "watchdog")] +use {kspin::SpinNoIrq, alloc::vec::Vec, crate::WeakAxTaskRef}; + macro_rules! percpu_static { ($( $(#[$comment:meta])* @@ -45,6 +48,18 @@ percpu_static! { PREV_TASK: Weak = Weak::new(), } +/// Stores all tasks for each CPU except those in the 'exited' state. +#[cfg(feature = "watchdog")] +static mut GLOBAL_TASK_QUEUES: [SpinNoIrq>; axconfig::plat::CPU_NUM] = + [const { SpinNoIrq::new(Vec::new()) }; axconfig::plat::CPU_NUM]; + +/// Returns a mutable reference to the global task queue of the given CPU. +#[cfg(feature = "watchdog")] +#[inline] +pub(crate) fn get_global_task_queue(cpu_id: usize) -> &'static SpinNoIrq> { + unsafe { &GLOBAL_TASK_QUEUES[cpu_id] } +} + /// An array of references to run queues, one for each CPU, indexed by cpu_id. /// /// This static variable holds references to the run queues for each CPU in the system. @@ -243,6 +258,10 @@ impl AxRunQueueRef<'_, G> { self.inner.cpu_id ); assert!(task.is_ready()); + #[cfg(feature = "watchdog")] + get_global_task_queue(self.inner.cpu_id) + .lock() + .push(Arc::downgrade(&task)); self.inner.scheduler.lock().add_task(task); } @@ -363,6 +382,10 @@ impl CurrentRunQueueRef<'_, G> { debug!("task exit: {}, exit_code={}", curr.id_name(), exit_code); assert!(curr.is_running(), "task is not running: {:?}", curr.state()); assert!(!curr.is_idle()); + #[cfg(feature = "watchdog")] + get_global_task_queue(self.inner.cpu_id) + .lock() + .retain(|weak_task| weak_task.upgrade().map_or(true, |t| t.id() != curr.id())); if curr.is_init() { // Safety: it is called from `current_run_queue::().exit_current(exit_code)`, // which disabled IRQs and preemption. @@ -438,8 +461,32 @@ impl AxRunQueue { // gc task should be pinned to the current CPU. gc_task.set_cpumask(AxCpuMask::one_shot(cpu_id)); + #[cfg(feature = "watchdog")] + get_global_task_queue(cpu_id) + .lock() + .push(Arc::downgrade(&gc_task)); + let mut scheduler = Scheduler::new(); scheduler.add_task(gc_task); + + #[cfg(feature = "watchdog")] + { + let watchdog_task = TaskInner::new( + move || loop { + axwatchdog::touch_softlockup(cpu_id, axhal::time::monotonic_time_nanos()); + crate::yield_now(); + }, + "watchdog".into(), + axconfig::TASK_STACK_SIZE, + ) + .into_arc(); + watchdog_task.set_cpumask(AxCpuMask::one_shot(cpu_id)); + get_global_task_queue(cpu_id) + .lock() + .push(Arc::downgrade(&watchdog_task)); + scheduler.add_task(watchdog_task); + } + Self { cpu_id, scheduler: SpinRaw::new(scheduler), diff --git a/scripts/make/cargo.mk b/scripts/make/cargo.mk index e4e96684ed..7e39985c90 100644 --- a/scripts/make/cargo.mk +++ b/scripts/make/cargo.mk @@ -17,7 +17,6 @@ build_args := \ $(build_args-$(MODE)) \ $(verbose) -RUSTFLAGS := -A unsafe_op_in_unsafe_fn RUSTFLAGS_LINK_ARGS := -C link-arg=-T$(LD_SCRIPT) -C link-arg=-no-pie -C link-arg=-znostart-stop-gc RUSTDOCFLAGS := -Z unstable-options --enable-index-page -D rustdoc::broken_intra_doc_links diff --git a/ulib/axstd/Cargo.toml b/ulib/axstd/Cargo.toml index 1cc5fafb02..6ab9e122e1 100644 --- a/ulib/axstd/Cargo.toml +++ b/ulib/axstd/Cargo.toml @@ -57,9 +57,9 @@ sched-cfs = ["axfeat/sched-cfs"] # File system fs = ["arceos_api/fs", "axfeat/fs"] -myfs = ["arceos_api/myfs", "axfeat/myfs"] -ext4fs = ["axfeat/ext4fs"] -fatfs = ["axfeat/fatfs"] +fs-ext4 = ["axfeat/fs-ext4"] +fs-fat = ["axfeat/fs-fat"] +fs-times = ["axfeat/fs-times"] # Networking net = ["arceos_api/net", "axfeat/net"] @@ -84,6 +84,7 @@ driver-ixgbe = ["axfeat/driver-ixgbe"] driver-fxmac = ["axfeat/driver-fxmac"] driver-bcm2835-sdhci = ["axfeat/driver-bcm2835-sdhci"] driver-dyn = ["axfeat/driver-dyn"] +driver-ahci = ["axfeat/driver-ahci"] # Backtrace dwarf = ["axfeat/dwarf"]