diff --git a/.github/workflows/go_test.yml b/.github/workflows/go_test.yml new file mode 100644 index 0000000..40ca116 --- /dev/null +++ b/.github/workflows/go_test.yml @@ -0,0 +1,26 @@ +name: Go Test CI + +on: + push: + branches: [ main, master ] # Adjust if your main branch has a different name + pull_request: + branches: [ main, master ] # Adjust if your main branch has a different name + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + go-version: ['1.23.x'] # Using a recent stable Go version + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go-version }} + + - name: Run tests + run: go test -v ./... diff --git a/serialfinder_darwin.go b/serialfinder_darwin.go index 47ac972..ccee41c 100644 --- a/serialfinder_darwin.go +++ b/serialfinder_darwin.go @@ -7,118 +7,164 @@ import ( "bufio" "bytes" "fmt" - "os/exec" + "os/exec" // Keep this for the default executor "regexp" "strconv" "strings" ) -// GetSerialDevices retrieves USB serial devices on macOS by querying the I/O Registry, -// filtering by VID and PID, and finding the corresponding device path. -func GetSerialDevices(vid, pid string) ([]SerialDeviceInfo, error) { - var devices []SerialDeviceInfo +// commandExecutor defines an interface for executing external commands. +// This allows for mocking exec.Command in tests. +type commandExecutor interface { + Execute(name string, arg ...string) ([]byte, error) +} + +// defaultExecutor is the default implementation of commandExecutor using exec.Command. +type defaultExecutor struct{} + +func (de *defaultExecutor) Execute(name string, arg ...string) ([]byte, error) { + cmd := exec.Command(name, arg...) + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr - // Use ioreg to get device information in a parseable format - // -c IOSerialBSDClient: Focus on serial port client drivers - // -r: Recursive search up the device tree to find parent USB devices - // -l: Show properties for each device - cmd := exec.Command("ioreg", "-r", "-c", "IOSerialBSDClient", "-l") - var out bytes.Buffer - cmd.Stdout = &out err := cmd.Run() if err != nil { - // Handle case where ioreg might fail or return non-zero if no devices found - // Check stderr? For now, assume error means failure or no devices. - // An empty output might just mean no serial devices connected. - if out.Len() == 0 { - // No output probably means no serial devices, not necessarily an error - return devices, nil + // Include stderr in the error message if available for better debugging. + if stderr.Len() > 0 { + return stdout.Bytes(), fmt.Errorf("command %s %v failed with error: %v, stderr: %s", name, strings.Join(arg, " "), err, stderr.String()) } - return nil, fmt.Errorf("failed to run ioreg: %v, output: %s", err, out.String()) + return stdout.Bytes(), fmt.Errorf("command %s %v failed with error: %v", name, strings.Join(arg, " "), err) + } + return stdout.Bytes(), nil +} + +// GetSerialDevices is the public function to retrieve USB serial devices on macOS. +// It uses the default command executor. +func GetSerialDevices(vid, pid string) ([]SerialDeviceInfo, error) { + return getSerialDevicesWithExecutor(vid, pid, &defaultExecutor{}) +} + +// getSerialDevicesWithExecutor is the internal implementation that allows using a custom commandExecutor. +// This is used for testing. +func getSerialDevicesWithExecutor(vid, pid string, executor commandExecutor) ([]SerialDeviceInfo, error) { + var devices []SerialDeviceInfo + + // Use ioreg to get device information. + ioregOutput, err := executor.Execute("ioreg", "-r", "-c", "IOSerialBSDClient", "-l") + if err != nil { + // If the command itself failed, this is an error. + // The executor.Execute should ideally include command output if err is not nil but output exists. + // Based on current defaultExecutor, ioregOutput might contain partial stdout on error. + // We wrap the error from the executor. + // If ioregOutput is also empty, it might indicate no devices OR a more fundamental issue. + // The error message from defaultExecutor already includes stderr. + return nil, fmt.Errorf("failed to execute ioreg: %w", err) + } + + // If ioreg ran successfully but produced no output, it means no serial devices were found. + if len(ioregOutput) == 0 { + return devices, nil } // Prepare VID/PID for case-insensitive comparison targetVidUpper := strings.ToUpper(vid) targetPidUpper := strings.ToUpper(pid) - scanner := bufio.NewScanner(&out) - var currentDevice *SerialDeviceInfo - var inUSBDeviceBlock bool // Flag to track if we are inside a relevant USB device entry + scanner := bufio.NewScanner(bytes.NewReader(ioregOutput)) + // currentUSBDevice holds properties of the most recently encountered USB device. + // We assume that an IOSerialBSDClient's properties will follow its parent USB device's properties. + var currentUSBDevice *SerialDeviceInfo - // Regex to extract key-value pairs like "key" = value - // Handles strings ("value"), numbers (123), hex numbers (0x123) + // Regex to extract key-value pairs: "key" = value reKeyValue := regexp.MustCompile(`"([^"]+)"\s*=\s*(.*)`) for scanner.Scan() { - line := scanner.Text() - - // Check if we are entering a new device potentially containing USB info - // Reset state if we leave an indented block associated with a potential USB parent - // This parsing logic is simplified; a full tree parser would be more robust. - // We primarily look for IOUSBHostDevice or IOUSBDevice containing VID/PID/Serial, - // and then find the child IOSerialBSDClient for the port. - if strings.Contains(line, " + // Or for the serial client: +-o IOSerialBSDClient + if strings.HasPrefix(line, "+-o") { + if strings.Contains(line, "IOUSBDevice") || strings.Contains(line, "IOUSBHostDevice") { + // New USB device encountered, reset currentUSBDevice + currentUSBDevice = &SerialDeviceInfo{} + } else if !strings.Contains(line, "IOSerialBSDClient") { + // If it's another type of device, and not the serial client itself, + // we might have left the scope of the current USB device. + // This is a heuristic: if an unrelated device appears, the previous USB context is likely no longer relevant + // for any subsequent IOSerialBSDClient unless a new USB device is explicitly listed. + currentUSBDevice = nil } + // If it's an IOSerialBSDClient line, we don't reset currentUSBDevice here, + // as the following lines will contain its properties, and we need the context + // of the *parent* USB device. } - if currentDevice != nil { - match := reKeyValue.FindStringSubmatch(strings.TrimSpace(line)) - if len(match) == 3 { - key := match[1] - value := strings.TrimSpace(match[2]) - - // Extract VID, PID, SerialNumber from the USB device block - if inUSBDeviceBlock { - switch key { - case "idVendor": - hexVal, err := parseHexValue(value) - if err == nil { - currentDevice.Vid = fmt.Sprintf("%04X", hexVal) - } - case "idProduct": - hexVal, err := parseHexValue(value) - if err == nil { - currentDevice.Pid = fmt.Sprintf("%04X", hexVal) - } - case "USB Serial Number": // Note: Key name can vary slightly (sometimes kUSBSerialNumberString) - currentDevice.SerialNumber = parseStringValue(value) - case "kUSBSerialNumberString": // Alternative key name - if currentDevice.SerialNumber == "" { // Prefer "USB Serial Number" if available - currentDevice.SerialNumber = parseStringValue(value) - } + match := reKeyValue.FindStringSubmatch(line) + if len(match) == 3 { + key := match[1] + value := strings.TrimSpace(match[2]) + + // Populate properties for the current USB device context + if currentUSBDevice != nil { + switch key { + case "idVendor": + hexVal, err := parseHexValue(value) + if err == nil { + currentUSBDevice.Vid = fmt.Sprintf("%04X", hexVal) + } + case "idProduct": + hexVal, err := parseHexValue(value) + if err == nil { + currentUSBDevice.Pid = fmt.Sprintf("%04X", hexVal) + } + // USB Product Name and Serial Number can also be extracted if needed, + // but are not strictly part of SerialDeviceInfo struct currently. + case "USB Serial Number", "kUSBSerialNumberString": + // Favor "USB Serial Number" but take kUSBSerialNumberString if the other is not present or empty. + // The check `currentUSBDevice.SerialNumber == ""` handles this implicitly if "USB Serial Number" comes first. + sn := parseStringValue(value) + if sn != "" { // Only overwrite if we get a non-empty serial number + currentUSBDevice.SerialNumber = sn } } + } - // Extract Port from the IOSerialBSDClient block (which is a child) - if key == "IOCalloutDevice" { - // This property belongs to the IOSerialBSDClient, which should be listed *after* - // its parent USB device properties in the `ioreg -r` output. + // Check for IOCalloutDevice, which indicates the serial port path. + // This property is part of the IOSerialBSDClient. + if key == "IOCalloutDevice" { + // We expect currentUSBDevice to be populated from the parent USB device + // that appeared earlier in the ioreg output. + if currentUSBDevice != nil && currentUSBDevice.Vid != "" && currentUSBDevice.Pid != "" { portPath := parseStringValue(value) - if portPath != "" && currentDevice.Vid != "" && currentDevice.Pid != "" { - currentDevice.Port = portPath - - // Check if VID/PID match the filter (if provided) - vidMatch := (targetVidUpper == "" || currentDevice.Vid == targetVidUpper) - pidMatch := (targetPidUpper == "" || currentDevice.Pid == targetPidUpper) + if portPath != "" { + // We have a potential serial device. Check against VID/PID filters. + // currentUSBDevice.Vid and currentUSBDevice.Pid are already uppercase from fmt.Sprintf("%04X"). + vidMatch := (targetVidUpper == "" || currentUSBDevice.Vid == targetVidUpper) + pidMatch := (targetPidUpper == "" || currentUSBDevice.Pid == targetPidUpper) if vidMatch && pidMatch { - // Found a matching device, add a copy to the list - devices = append(devices, *currentDevice) + // Create a new SerialDeviceInfo for the list, copying relevant USB properties. + device := SerialDeviceInfo{ + Port: portPath, + Vid: currentUSBDevice.Vid, + Pid: currentUSBDevice.Pid, + SerialNumber: currentUSBDevice.SerialNumber, + // Description could be added here if parsed, e.g., from "USB Product Name" + } + devices = append(devices, device) } - // Reset for the next potential device block found by ioreg - // Since IOCalloutDevice is usually the last relevant piece, reset here. - currentDevice = nil - inUSBDeviceBlock = false } } + // After processing an IOCalloutDevice, the properties of currentUSBDevice have been used + // or deemed irrelevant. It's not strictly necessary to reset currentUSBDevice here, + // as a new "+-o IOUSB..." line will do that. However, if multiple IOSerialBSDClient + // entries were nested under one IOUSBDevice (uncommon for distinct physical ports), + // not resetting could lead to issues. For typical scenarios, this is okay. + // For now, let the next "+-o IOUSB..." line handle the reset of currentUSBDevice. } } } @@ -130,30 +176,19 @@ func GetSerialDevices(vid, pid string) ([]SerialDeviceInfo, error) { return devices, nil } -// parseHexValue converts ioreg number values (like 0x1234 or 1234) to int64 +// parseHexValue converts ioreg number values to int64. +// ioreg typically outputs VID/PID as decimal numbers, but can also use "0x" prefix for hex. func parseHexValue(value string) (int64, error) { value = strings.TrimSpace(value) - // Remove trailing comma if present (sometimes happens in ioreg output) - value = strings.TrimSuffix(value, ",") + value = strings.TrimSuffix(value, ",") // Remove trailing comma - // Check if it's already a decimal number - decVal, errDec := strconv.ParseInt(value, 10, 64) - if errDec == nil { - return decVal, nil - } - - // Try parsing as hex (ioreg usually uses 0x prefix, but let's be flexible) - if strings.HasPrefix(value, "0x") { + if strings.HasPrefix(value, "0x") || strings.HasPrefix(value, "0X") { + // Explicitly hex if "0x" prefix is present return strconv.ParseInt(value[2:], 16, 64) } - // Fallback attempt if no prefix but maybe hex? Unlikely needed for VID/PID. - hexVal, errHex := strconv.ParseInt(value, 16, 64) - if errHex == nil { - return hexVal, nil - } - - // Return the original decimal error if hex also failed - return 0, errDec + // Otherwise, assume it's a decimal number (standard for ioreg idVendor/idProduct) + // If it's not a valid decimal, this will return an error. + return strconv.ParseInt(value, 10, 64) } // parseStringValue extracts string values like "My String" -> My String @@ -161,9 +196,9 @@ func parseStringValue(value string) string { value = strings.TrimSpace(value) // Remove trailing comma if present value = strings.TrimSuffix(value, ",") - // Remove surrounding quotes - if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) { + // Remove surrounding quotes only if the string is long enough to contain them + if len(value) >= 2 && strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) { return value[1 : len(value)-1] } - return value // Return as-is if not quoted + return value // Return as-is if not properly quoted or too short } diff --git a/serialfinder_darwin_test.go b/serialfinder_darwin_test.go new file mode 100644 index 0000000..991c0d2 --- /dev/null +++ b/serialfinder_darwin_test.go @@ -0,0 +1,445 @@ +//go:build darwin +// +build darwin + +package serialfinder + +import ( + "errors" + "reflect" + "strings" + "testing" +) + +// Mock commandExecutor for testing +type mockExecutor struct { + Output []byte + Err error + CalledName string + CalledArgs []string +} + +func (me *mockExecutor) Execute(name string, arg ...string) ([]byte, error) { + me.CalledName = name + me.CalledArgs = arg + return me.Output, me.Err +} + +func TestParseHexValue(t *testing.T) { + t.Helper() + tests := []struct { + name string + input string + want int64 + wantErr bool + }{ + {"valid decimal", "1234", 1234, false}, + {"valid hex with 0x", "0x4D2", 1234, false}, + {"valid hex with 0X", "0X4d2", 1234, false}, + {"valid decimal with comma", "1234,", 1234, false}, + {"valid hex with comma", "0x4D2,", 1234, false}, + {"zero decimal", "0", 0, false}, + {"zero hex", "0x0", 0, false}, + {"large decimal", "1234567890", 1234567890, false}, + {"large hex", "0xABCDEF12", 0xABCDEF12, false}, + {"invalid input - letters", "abc", 0, true}, + {"invalid input - hex letters no prefix", "ABC", 0, true}, // parseHexValue expects decimal if no 0x + {"empty input", "", 0, true}, + {"only 0x", "0x", 0, true}, + {"hex with invalid chars", "0xGHI", 0, true}, + {"decimal with space", "12 34", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseHexValue(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseHexValue(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("parseHexValue(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestParseStringValue(t *testing.T) { + t.Helper() + tests := []struct { + name string + input string + want string + }{ + {"quoted string", `"Hello, World!"`, "Hello, World!"}, + {"unquoted string", `MyDevice`, "MyDevice"}, + {"quoted string with comma", `"Test,"`, "Test"}, + {"unquoted string with comma", `Test,`, "Test"}, // TrimSuffix will remove it + {"empty quoted string", `""`, ""}, + {"empty unquoted string", ``, ""}, + {"string with internal spaces", `"Spaces In Side"`, "Spaces In Side"}, + {"string with leading/trailing spaces in quotes", `" Spaced "`, " Spaced "}, + {"string with only spaces in quotes", `" "`, " "}, + {"already trimmed string", `NoQuotes`, "NoQuotes"}, + {"string is just a quote", `"`, `"`}, // Does not have prefix and suffix of quote + {"string is two quotes", `""`, ``}, // Has prefix and suffix + {"string is three quotes", `"""`, `"`}, // Strips first and last + {"string with internal quotes", `"a"b"c"`, `a"b"c`}, // Strips first and last + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseStringValue(tt.input); got != tt.want { + t.Errorf("parseStringValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +const mockIoregOutputEmpty = ` ++-o Root +{ +} +` + +const mockIoregOutputSingleDevice = ` ++-o Root + +-o IOUSBHostDevice + { + "sessionID" = 1112309454 + "idProduct" = 22332 // PID: 0x573C + "idVendor" = 1155 // VID: 0x0483 + "kUSBSerialNumberString" = "SERIAL123" + "USB Product Name" = "Test USB Device" + } + +-o AppleUSBHostCompositeDevice + +-o AppleUSBHostInterface@0 + +-o IOSerialBSDClient + { + "IOTTYBaseName" = "usbmodem" + "IOCalloutDevice" = "/dev/cu.usbmodemSERIAL1231" + "IODialinDevice" = "/dev/tty.usbmodemSERIAL1231" + "IOTTYDevice" = "usbmodemSERIAL1231" + "idProduct" = 22332 + "idVendor" = 1155 + } +` +const mockIoregOutputSingleDeviceFTDI = ` ++-o Root + +-o IOUSBDevice + { + "sessionID" = 1112309454 + "idProduct" = 24577 // PID: 0x6001 + "idVendor" = 1027 // VID: 0x0403 + "USB Serial Number" = "FTDI_SERIAL" + } + +-o AppleUSBInterface@0 + +-o IOSerialBSDClient + { + "IOCalloutDevice" = "/dev/cu.usbserial-FTDI_SERIAL" + } +` + +const mockIoregOutputTwoDevices = mockIoregOutputSingleDevice + ` + +-o IOUSBHostDevice + { + "idProduct" = 8193 // PID: 0x2001 + "idVendor" = 4292 // VID: 0x10C4 + "kUSBSerialNumberString" = "SERIAL_XYZ" + } + +-o AppleUSBHostCompositeDevice + +-o AppleUSBHostInterface@0 + +-o IOSerialBSDClient + { + "IOCalloutDevice" = "/dev/cu.usbmodemSERIAL_XYZ1" + } +` +const mockIoregOutputMissingSerial = ` ++-o Root + +-o IOUSBHostDevice + { + "idProduct" = 22332 + "idVendor" = 1155 + // No Serial Number + } + +-o AppleUSBHostCompositeDevice + +-o AppleUSBHostInterface@0 + +-o IOSerialBSDClient + { + "IOCalloutDevice" = "/dev/cu.usbmodemNOSERIAL1" + } +` + +const mockIoregOutputMissingVID = ` ++-o Root + +-o IOUSBHostDevice + { + "idProduct" = 22332 + // No idVendor + "kUSBSerialNumberString" = "SERIAL_NOVID" + } + +-o AppleUSBHostCompositeDevice + +-o AppleUSBHostInterface@0 + +-o IOSerialBSDClient + { + "IOCalloutDevice" = "/dev/cu.usbmodemSERIAL_NOVID1" + } +` +const mockIoregOutputMissingPID = ` ++-o Root + +-o IOUSBHostDevice + { + "idVendor" = 1155 + // No idProduct + "kUSBSerialNumberString" = "SERIAL_NOPID" + } + +-o AppleUSBHostCompositeDevice + +-o AppleUSBHostInterface@0 + +-o IOSerialBSDClient + { + "IOCalloutDevice" = "/dev/cu.usbmodemSERIAL_NOPID1" + } +` +const mockIoregOutputMissingPort = ` ++-o Root + +-o IOUSBHostDevice + { + "idProduct" = 22332 + "idVendor" = 1155 + "kUSBSerialNumberString" = "SERIAL_NOPORT" + } + +-o AppleUSBHostCompositeDevice + +-o AppleUSBHostInterface@0 + +-o IOSerialBSDClient + { + // No IOCalloutDevice + } +` + +func TestGetSerialDevicesWithExecutor_NoDevices(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputEmpty)} + devices, err := getSerialDevicesWithExecutor("", "", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 0 { + t.Fatalf("expected 0 devices, got %d", len(devices)) + } +} + +func TestGetSerialDevicesWithExecutor_IoregError(t *testing.T) { + t.Helper() + expectedErr := errors.New("ioreg command failed") + executor := &mockExecutor{Err: expectedErr} + _, err := getSerialDevicesWithExecutor("", "", executor) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + if !strings.Contains(err.Error(), expectedErr.Error()) { + t.Errorf("expected error string '%v' to contain '%v'", err, expectedErr) + } +} + +func TestGetSerialDevicesWithExecutor_SingleDevice_NoFilter(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputSingleDevice)} + devices, err := getSerialDevicesWithExecutor("", "", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 1 { + t.Fatalf("expected 1 device, got %d", len(devices)) + } + expected := SerialDeviceInfo{ + Vid: "0483", // 1155 + Pid: "573C", // 22332 + SerialNumber: "SERIAL123", + Port: "/dev/cu.usbmodemSERIAL1231", + } + if !reflect.DeepEqual(devices[0], expected) { + t.Errorf("device info mismatch:\ngot %+v\nwant %+v", devices[0], expected) + } +} + +func TestGetSerialDevicesWithExecutor_SingleDevice_FTDI(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputSingleDeviceFTDI)} + devices, err := getSerialDevicesWithExecutor("0403", "6001", executor) // VID: 0x0403, PID: 0x6001 + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 1 { + t.Fatalf("expected 1 device, got %d: %+v", len(devices), devices) + } + expected := SerialDeviceInfo{ + Vid: "0403", + Pid: "6001", + SerialNumber: "FTDI_SERIAL", + Port: "/dev/cu.usbserial-FTDI_SERIAL", + } + if !reflect.DeepEqual(devices[0], expected) { + t.Errorf("device info mismatch:\ngot %+v\nwant %+v", devices[0], expected) + } +} + + +func TestGetSerialDevicesWithExecutor_SingleDevice_WithVIDPIDFilterMatch(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputSingleDevice)} + // VID: 0x0483, PID: 0x573C + devices, err := getSerialDevicesWithExecutor("0483", "573C", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 1 { + t.Fatalf("expected 1 device, got %d", len(devices)) + } +} + +func TestGetSerialDevicesWithExecutor_SingleDevice_WithVIDFilterMismatch(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputSingleDevice)} + devices, err := getSerialDevicesWithExecutor("FFFF", "573C", executor) // Mismatched VID + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 0 { + t.Fatalf("expected 0 devices due to VID mismatch, got %d", len(devices)) + } +} + +func TestGetSerialDevicesWithExecutor_SingleDevice_WithPIDFilterMismatch(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputSingleDevice)} + devices, err := getSerialDevicesWithExecutor("0483", "FFFF", executor) // Mismatched PID + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 0 { + t.Fatalf("expected 0 devices due to PID mismatch, got %d", len(devices)) + } +} + +func TestGetSerialDevicesWithExecutor_VIDPIDCaseInsensitiveFilter(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputSingleDevice)} + // VID: 0x0483, PID: 0x573C. Device stores them as "0483", "573C". + // Test with lowercase filter. + devices, err := getSerialDevicesWithExecutor("0x483", "0x573c", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 1 { + t.Errorf("expected 1 device with case-insensitive filter, got %d", len(devices)) + } +} + +func TestGetSerialDevicesWithExecutor_MultipleDevices(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputTwoDevices)} + devices, err := getSerialDevicesWithExecutor("", "", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 2 { + t.Fatalf("expected 2 devices, got %d", len(devices)) + } + + expected1 := SerialDeviceInfo{ + Vid: "0483", // 1155 + Pid: "573C", // 22332 + SerialNumber: "SERIAL123", + Port: "/dev/cu.usbmodemSERIAL1231", + } + expected2 := SerialDeviceInfo{ + Vid: "10C4", // 4292 + Pid: "2001", // 8193 + SerialNumber: "SERIAL_XYZ", + Port: "/dev/cu.usbmodemSERIAL_XYZ1", + } + + // Check if both expected devices are present, order might vary + found1 := false + found2 := false + for _, device := range devices { + if reflect.DeepEqual(device, expected1) { + found1 = true + } + if reflect.DeepEqual(device, expected2) { + found2 = true + } + } + if !found1 || !found2 { + t.Errorf("did not find all expected devices.\nGot: %+v\nExpected to find: %+v and %+v", devices, expected1, expected2) + } + + // Test with filter matching one + devices, err = getSerialDevicesWithExecutor("10C4", "2001", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 1 { + t.Fatalf("expected 1 device with filter, got %d", len(devices)) + } + if !reflect.DeepEqual(devices[0], expected2) { + t.Errorf("device info mismatch with filter:\ngot %+v\nwant %+v", devices[0], expected2) + } +} + + +func TestGetSerialDevicesWithExecutor_DeviceWithMissingSerialNumber(t *testing.T) { + t.Helper() + executor := &mockExecutor{Output: []byte(mockIoregOutputMissingSerial)} + devices, err := getSerialDevicesWithExecutor("0483", "573C", executor) // VID: 1155->0483, PID: 22332->573C + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 1 { + t.Fatalf("expected 1 device even with missing serial, got %d", len(devices)) + } + if devices[0].SerialNumber != "" { + t.Errorf("expected empty SerialNumber, got %s", devices[0].SerialNumber) + } + if devices[0].Port != "/dev/cu.usbmodemNOSERIAL1" { + t.Errorf("expected Port '/dev/cu.usbmodemNOSERIAL1', got %s", devices[0].Port) + } +} + +func TestGetSerialDevicesWithExecutor_DeviceWithMissingVID(t *testing.T) { + t.Helper() + // The parser skips devices if VID/PID cannot be determined from the USB device block. + executor := &mockExecutor{Output: []byte(mockIoregOutputMissingVID)} + devices, err := getSerialDevicesWithExecutor("", "", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 0 { + t.Fatalf("expected 0 devices when VID is missing from USB block, got %d: %+v", len(devices), devices) + } +} + +func TestGetSerialDevicesWithExecutor_DeviceWithMissingPID(t *testing.T) { + t.Helper() + // The parser skips devices if VID/PID cannot be determined from the USB device block. + executor := &mockExecutor{Output: []byte(mockIoregOutputMissingPID)} + devices, err := getSerialDevicesWithExecutor("", "", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 0 { + t.Fatalf("expected 0 devices when PID is missing from USB block, got %d: %+v", len(devices), devices) + } +} + +func TestGetSerialDevicesWithExecutor_DeviceWithMissingPort(t *testing.T) { + t.Helper() + // If IOCalloutDevice is missing, the device won't be added. + executor := &mockExecutor{Output: []byte(mockIoregOutputMissingPort)} + devices, err := getSerialDevicesWithExecutor("0483", "573C", executor) + if err != nil { + t.Fatalf("getSerialDevicesWithExecutor returned error: %v", err) + } + if len(devices) != 0 { + t.Fatalf("expected 0 devices when IOCalloutDevice is missing, got %d: %+v", len(devices), devices) + } +} diff --git a/serialfinder_linux.go b/serialfinder_linux.go index 4e3ed1a..904c5d9 100644 --- a/serialfinder_linux.go +++ b/serialfinder_linux.go @@ -5,24 +5,66 @@ package serialfinder import ( "fmt" + "io/fs" // For fs.FileMode "os" "path/filepath" "strings" ) -// GetSerialDevices retrieves USB devices on Linux by searching the `/dev/serial/by-id` directory, filtering by VID and PID, and finding the corresponding port +// fileSystemReader defines an interface for filesystem operations. +// This allows for mocking the filesystem in tests. +type fileSystemReader interface { + ReadDir(dirname string) ([]os.DirEntry, error) + EvalSymlinks(path string) (string, error) + ReadFile(filename string) ([]byte, error) + Stat(name string) (os.FileInfo, error) // For checkForVIDPIDFiles +} + +// defaultFileSystemReader is the default implementation of fileSystemReader using os and filepath. +type defaultFileSystemReader struct{} + +func (r *defaultFileSystemReader) ReadDir(dirname string) ([]os.DirEntry, error) { + return os.ReadDir(dirname) +} +func (r *defaultFileSystemReader) EvalSymlinks(path string) (string, error) { + return filepath.EvalSymlinks(path) +} +func (r *defaultFileSystemReader) ReadFile(filename string) ([]byte, error) { + return os.ReadFile(filename) +} +func (r *defaultFileSystemReader) Stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} + +// GetSerialDevices is the public function to retrieve USB devices on Linux. +// It uses the default file system reader. func GetSerialDevices(vid, pid string) ([]SerialDeviceInfo, error) { + return getSerialDevicesWithReader(vid, pid, &defaultFileSystemReader{}) +} + +// getSerialDevicesWithReader is the internal implementation that allows using a custom fileSystemReader. +// This is used for testing. +func getSerialDevicesWithReader(vid, pid string, reader fileSystemReader) ([]SerialDeviceInfo, error) { var devices []SerialDeviceInfo // Path to the serial devices by ID directory serialByIDPath := "/dev/serial/by-id" // Read all the symlinks in the directory - entries, err := os.ReadDir(serialByIDPath) + entries, err := reader.ReadDir(serialByIDPath) if err != nil { - return nil, err + // If /dev/serial/by-id doesn't exist or is unreadable, it might mean no devices or a permission issue. + // This is a common scenario if no relevant udev rules created these symlinks. + if os.IsNotExist(err) || os.IsPermission(err) { + return devices, nil // Return empty list, not an error + } + return nil, fmt.Errorf("error reading %s: %w", serialByIDPath, err) } + // Prepare VID/PID for case-insensitive comparison + targetVidUpper := strings.ToUpper(vid) + targetPidUpper := strings.ToUpper(pid) + // Iterate over each entry in the directory for _, entry := range entries { if entry.IsDir() { @@ -33,89 +75,120 @@ func GetSerialDevices(vid, pid string) ([]SerialDeviceInfo, error) { symlinkPath := filepath.Join(serialByIDPath, entry.Name()) // Resolve the symbolic link to get the actual device path - devicePath, err := filepath.EvalSymlinks(symlinkPath) + devicePath, err := reader.EvalSymlinks(symlinkPath) if err != nil { + // Could be a broken symlink, skip it. continue } // Find the USB device directory associated with this tty device - usbDir := findSerialDeviceInfoDir(devicePath) + usbDir := findSerialDeviceInfoDirWithReader(devicePath, reader) if usbDir == "" { continue } // Read the VID and PID - idVendor, err := os.ReadFile(filepath.Join(usbDir, "idVendor")) + idVendorBytes, err := reader.ReadFile(filepath.Join(usbDir, "idVendor")) if err != nil { - fmt.Printf("Error reading idVendor: %v\n", err) - continue + // If we can't read VID, this device is problematic. + // Depending on desired strictness, could continue or return error. + // For now, let's be strict as VID/PID are crucial. + return nil, fmt.Errorf("error reading idVendor for %s (from %s): %w", usbDir, symlinkPath, err) } + idVendor := idVendorBytes - idProduct, err := os.ReadFile(filepath.Join(usbDir, "idProduct")) + idProductBytes, err := reader.ReadFile(filepath.Join(usbDir, "idProduct")) if err != nil { - fmt.Printf("Error reading idProduct: %v\n", err) - continue + return nil, fmt.Errorf("error reading idProduct for %s (from %s): %w", usbDir, symlinkPath, err) } + idProduct := idProductBytes - // Log the VID and PID for debugging vidStr := strings.ToUpper(strings.TrimSpace(string(idVendor))) pidStr := strings.ToUpper(strings.TrimSpace(string(idProduct))) - // Check if the VID and PID match the specified values - if vidStr != "" && vidStr != vid { + // Filter by VID if a VID is provided + if targetVidUpper != "" && vidStr != targetVidUpper { continue } - if pidStr != "" && pidStr != pid { + // Filter by PID if a PID is provided + if targetPidUpper != "" && pidStr != targetPidUpper { continue } // Read the serial number - serialNumber, err := os.ReadFile(filepath.Join(usbDir, "serial")) + var serialNumberStr string + serialNumberBytes, err := reader.ReadFile(filepath.Join(usbDir, "serial")) if err != nil { - fmt.Printf("Error reading serial: %v\n", err) - serialNumber = []byte("") + // Non-critical if serial is missing, proceed with an empty serial number. + serialNumberStr = "" + } else { + serialNumberStr = strings.TrimSpace(string(serialNumberBytes)) } // Add the device to the list + // Port is the stable /dev/serial/by-id path, which is useful for persistent device naming. devices = append(devices, SerialDeviceInfo{ - SerialNumber: strings.TrimSpace(string(serialNumber)), + SerialNumber: serialNumberStr, Vid: vidStr, Pid: pidStr, - Port: symlinkPath, + Port: symlinkPath, // symlinkPath is e.g., /dev/serial/by-id/usb-MyDevice_Serial-if00-port0 }) } return devices, nil } -// findSerialDeviceInfoDir returns the directory path of the USB device corresponding to the device path -func findSerialDeviceInfoDir(devicePath string) string { +// findSerialDeviceInfoDirWithReader is the testable version of findSerialDeviceInfoDir. +func findSerialDeviceInfoDirWithReader(devicePath string, reader fileSystemReader) string { // Get the full path to the tty device in /sys/class/tty + // devicePath is something like /dev/ttyUSB0 or /dev/ttyACM0 + // We need its base name, e.g., ttyUSB0 sysTTYPath := filepath.Join("/sys/class/tty", filepath.Base(devicePath), "device") - // Follow the symlink to the actual device directory - usbDir, err := filepath.EvalSymlinks(sysTTYPath) + // Follow the symlink to the actual device directory in sysfs (e.g., /sys/devices/pci0000:00/0000:00:14.0/usb1/1-1/1-1:1.0) + usbDeviceSysfsPath, err := reader.EvalSymlinks(sysTTYPath) if err != nil { - return "" + return "" // Cannot resolve path to device's sysfs directory + } + + // The usbDeviceSysfsPath usually points to an interface directory (e.g., /sys/.../1-1:1.0). + // The actual USB device directory (containing idVendor, idProduct) is typically one or two levels up. + // Example: /sys/devices/pci0000:00/0000:00:14.0/usb1/1-1 <-- This is what we want + // /sys/devices/pci0000:00/0000:00:14.0/usb1/1-1/1-1:1.0 + // /sys/devices/pci0000:00/0000:00:14.0/usb1/1-1/1-1:1.1 + // + // Check current directory (usbDeviceSysfsPath could sometimes be the main device dir, though less common for USB serial) + if checkForVIDPIDFilesWithReader(usbDeviceSysfsPath, reader) { + return usbDeviceSysfsPath } - // Navigate up one or two directories to find the actual USB device directory - parentDir := filepath.Dir(usbDir) - if checkForVIDPIDFiles(parentDir) { + // Check parent directory + parentDir := filepath.Dir(usbDeviceSysfsPath) + if checkForVIDPIDFilesWithReader(parentDir, reader) { return parentDir } + // For some devices, it might be two levels up (e.g. if usbDeviceSysfsPath was .../1-1/1-1.0/tty/ttyUSB0) + // but the typical structure for /sys/class/tty/{ttyX}/device -> .../usbX/X-Y/X-Y:Z.A + // has the VID/PID files in .../usbX/X-Y. + // So checking grandparentDir of usbDeviceSysfsPath (which is parentDir of parentDir) grandparentDir := filepath.Dir(parentDir) - if checkForVIDPIDFiles(grandparentDir) { - return grandparentDir + if grandparentDir != parentDir && grandparentDir != "." && grandparentDir != "/" { // Avoid going too high + if checkForVIDPIDFilesWithReader(grandparentDir, reader) { + return grandparentDir + } } + // Further check for cases like /sys/devices/.../usb1/1-1/1-1.0/device/../.. (less direct but possible) + // The current logic for parentDir and grandparentDir should cover most standard cases where + // 'device' symlinks to something like '.../1-1:1.0' and VID/PID are in '.../1-1'. - return "" + return "" // Could not find a directory with idVendor/idProduct files } -// checkForVIDPIDFiles checks if the directory contains idVendor and idProduct files -func checkForVIDPIDFiles(dir string) bool { - _, errVid := os.Stat(filepath.Join(dir, "idVendor")) - _, errPid := os.Stat(filepath.Join(dir, "idProduct")) +// checkForVIDPIDFilesWithReader is the testable version of checkForVIDPIDFiles. +func checkForVIDPIDFilesWithReader(dir string, reader fileSystemReader) bool { + // Check if idVendor and idProduct files exist in the directory + _, errVid := reader.Stat(filepath.Join(dir, "idVendor")) + _, errPid := reader.Stat(filepath.Join(dir, "idProduct")) return errVid == nil && errPid == nil } diff --git a/serialfinder_linux_test.go b/serialfinder_linux_test.go new file mode 100644 index 0000000..7324f14 --- /dev/null +++ b/serialfinder_linux_test.go @@ -0,0 +1,576 @@ +//go:build linux +// +build linux + +package serialfinder + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" +) + +// mockDirEntry implements os.DirEntry for testing. +type mockDirEntry struct { + name string + isDir bool + mode fs.FileMode +} + +func (mde *mockDirEntry) Name() string { return mde.name } +func (mde *mockDirEntry) IsDir() bool { return mde.isDir } +func (mde *mockDirEntry) Type() fs.FileMode { + if mde.mode != 0 { + return mde.mode & fs.ModeType // Return only type bits + } + if mde.isDir { + return fs.ModeDir + } + return 0 // Regular file +} +func (mde *mockDirEntry) Info() (fs.FileInfo, error) { + return &mockFileInfo{name: mde.name, isDir: mde.isDir, mode: mde.mode}, nil +} + +// mockFileInfo implements os.FileInfo for testing. +type mockFileInfo struct { + name string + isDir bool + mode fs.FileMode + modTime time.Time + size int64 +} + +func (mfi *mockFileInfo) Name() string { return mfi.name } +func (mfi *mockFileInfo) Size() int64 { return mfi.size } +func (mfi *mockFileInfo) Mode() fs.FileMode { return mfi.mode } +func (mfi *mockFileInfo) ModTime() time.Time { return mfi.modTime } +func (mfi *mockFileInfo) IsDir() bool { return mfi.isDir } +func (mfi *mockFileInfo) Sys() interface{} { return nil } + +// mockFileSystemReader implements fileSystemReader for testing. +type mockFileSystemReader struct { + mockFiles map[string][]byte + mockDirs map[string][]os.DirEntry + mockSymlinks map[string]string // path -> target + mockStats map[string]os.FileInfo + mockStatErrors map[string]error // path -> error for Stat + + // Specific errors for methods + readDirError error + evalSymlinksError map[string]error // path -> error + readFileError map[string]error // path -> error +} + +func newMockFileSystemReader() *mockFileSystemReader { + return &mockFileSystemReader{ + mockFiles: make(map[string][]byte), + mockDirs: make(map[string][]os.DirEntry), + mockSymlinks: make(map[string]string), + mockStats: make(map[string]os.FileInfo), + mockStatErrors: make(map[string]error), + evalSymlinksError: make(map[string]error), + readFileError: make(map[string]error), + } +} + +func (m *mockFileSystemReader) ReadDir(dirname string) ([]os.DirEntry, error) { + if m.readDirError != nil { + return nil, m.readDirError + } + entries, ok := m.mockDirs[dirname] + if !ok { + return nil, os.ErrNotExist // Default to NotExist if dir not explicitly mocked + } + return entries, nil +} + +func (m *mockFileSystemReader) EvalSymlinks(path string) (string, error) { + if err, ok := m.evalSymlinksError[path]; ok && err != nil { + return "", err + } + target, ok := m.mockSymlinks[path] + if !ok { + // If not a mocked symlink, behave like EvalSymlinks on a regular file/dir + // or return a specific error if it should be a symlink that's missing. + // For simplicity here, if not in map, assume it's not a symlink and return path itself or an error. + // The actual function expects EvalSymlinks to resolve or fail. + return "", os.ErrNotExist // Or return path, "", if it's not necessarily an error for it not to be a symlink + } + return target, nil +} + +func (m *mockFileSystemReader) ReadFile(filename string) ([]byte, error) { + if err, ok := m.readFileError[filename]; ok && err != nil { + return nil, err + } + content, ok := m.mockFiles[filename] + if !ok { + return nil, os.ErrNotExist + } + return content, nil +} + +func (m *mockFileSystemReader) Stat(name string) (os.FileInfo, error) { + if err, ok := m.mockStatErrors[name]; ok && err != nil { + return nil, err + } + info, ok := m.mockStats[name] + if !ok { + return nil, os.ErrNotExist + } + return info, nil +} + +// Helper to add a mock file +func (m *mockFileSystemReader) addFile(path string, content string) { + m.mockFiles[path] = []byte(content) + m.mockStats[path] = &mockFileInfo{name: filepath.Base(path), size: int64(len(content))} +} + +// Helper to add a mock symlink +func (m *mockFileSystemReader) addSymlink(path string, target string) { + m.mockSymlinks[path] = target + // Stat on a symlink usually returns info about the symlink itself + m.mockStats[path] = &mockFileInfo{name: filepath.Base(path), mode: fs.ModeSymlink} +} + +// Helper to add a mock directory entry for ReadDir +func (m *mockFileSystemReader) addDirEntry(dirPath string, entry os.DirEntry) { + m.mockDirs[dirPath] = append(m.mockDirs[dirPath], entry) +} + +// Helper to set a specific error for Stat +func (m *mockFileSystemReader) setStatError(path string, err error) { + m.mockStatErrors[path] = err +} + +// Helper to set a specific error for ReadFile +func (m *mockFileSystemReader) setReadFileError(path string, err error) { + if m.readFileError == nil { + m.readFileError = make(map[string]error) + } + m.readFileError[path] = err +} + +// Helper to set a specific error for EvalSymlinks +func (m *mockFileSystemReader) setEvalSymlinksError(path string, err error) { + if m.evalSymlinksError == nil { + m.evalSymlinksError = make(map[string]error) + } + m.evalSymlinksError[path] = err +} + + +func TestGetSerialDevicesWithReader(t *testing.T) { + t.Helper() + const byIDPath = "/dev/serial/by-id" + + tests := []struct { + name string + vidFilter string + pidFilter string + setupMock func(*mockFileSystemReader) + expected []SerialDeviceInfo + wantErr bool + }{ + { + name: "No devices in by-id path (empty dir)", + setupMock: func(mfs *mockFileSystemReader) { + mfs.mockDirs[byIDPath] = []os.DirEntry{} + }, + expected: []SerialDeviceInfo{}, + }, + { + name: "ReadDir for by-id path returns os.ErrNotExist", + setupMock: func(mfs *mockFileSystemReader) { + mfs.readDirError = os.ErrNotExist + }, + expected: []SerialDeviceInfo{}, // Should return empty, not error + }, + { + name: "ReadDir for by-id path returns other error", + setupMock: func(mfs *mockFileSystemReader) { + mfs.readDirError = errors.New("some ReadDir error") + }, + wantErr: true, + }, + { + name: "Single device, no filter", + setupMock: func(mfs *mockFileSystemReader) { + // /dev/serial/by-id entry + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-MyCorp_MyDevice_SERIAL123-if00-port0", mode: fs.ModeSymlink}) + // Symlink target for by-id entry + byIDSymlinkPath := filepath.Join(byIDPath, "usb-MyCorp_MyDevice_SERIAL123-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") // Target: /dev/ttyUSB0 + // /sys/class/tty/ttyUSB0/device symlink + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" // Assumed parent containing VID/PID + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) // -> /sys/devices/pci0/usb1/1-1/1-1:1.0 + // VID/PID/Serial files + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idVendor"), "0403\n") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "6001\n") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "serial"), "SERIAL123\n") + }, + expected: []SerialDeviceInfo{ + {Vid: "0403", Pid: "6001", SerialNumber: "SERIAL123", Port: filepath.Join(byIDPath, "usb-MyCorp_MyDevice_SERIAL123-if00-port0")}, + }, + }, + { + name: "Single device, matches VID/PID filter", + vidFilter: "0403", pidFilter: "6001", + setupMock: func(mfs *mockFileSystemReader) { + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-VID0403_PID6001_SERIAL123-if00-port0", mode: fs.ModeSymlink}) + byIDSymlinkPath := filepath.Join(byIDPath, "usb-VID0403_PID6001_SERIAL123-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idVendor"), "0403") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "6001") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "serial"), "SERIAL123") + }, + expected: []SerialDeviceInfo{ + {Vid: "0403", Pid: "6001", SerialNumber: "SERIAL123", Port: filepath.Join(byIDPath, "usb-VID0403_PID6001_SERIAL123-if00-port0")}, + }, + }, + { + name: "Single device, VID filter mismatch", + vidFilter: "FFFF", pidFilter: "6001", + setupMock: func(mfs *mockFileSystemReader) { + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-VID0403_PID6001_SERIAL123-if00-port0", mode: fs.ModeSymlink}) + byIDSymlinkPath := filepath.Join(byIDPath, "usb-VID0403_PID6001_SERIAL123-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idVendor"), "0403") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "6001") + }, + expected: []SerialDeviceInfo{}, + }, + { + name: "Single device, PID filter mismatch", + vidFilter: "0403", pidFilter: "FFFF", + setupMock: func(mfs *mockFileSystemReader) { + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-VID0403_PID6001_SERIAL123-if00-port0", mode: fs.ModeSymlink}) + byIDSymlinkPath := filepath.Join(byIDPath, "usb-VID0403_PID6001_SERIAL123-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idVendor"), "0403") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "6001") + }, + expected: []SerialDeviceInfo{}, + }, + { + name: "EvalSymlinks error for by-id symlink", + setupMock: func(mfs *mockFileSystemReader) { + byIDSymlinkPath := filepath.Join(byIDPath, "usb-SomeDevice-if00") + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-SomeDevice-if00", mode: fs.ModeSymlink}) + mfs.setEvalSymlinksError(byIDSymlinkPath, errors.New("eval error for by-id link")) + }, + expected: []SerialDeviceInfo{}, // Skips this device + }, + { + name: "ReadFile error for idVendor", + setupMock: func(mfs *mockFileSystemReader) { + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-VID0403_PID6001_SERIAL123-if00-port0", mode: fs.ModeSymlink}) + byIDSymlinkPath := filepath.Join(byIDPath, "usb-VID0403_PID6001_SERIAL123-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) + mfs.setReadFileError(filepath.Join(usbDeviceSysfsDir, "idVendor"), errors.New("read idVendor error")) + // idProduct is still there, but ReadFile for idVendor fails + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "6001") + }, + wantErr: true, // Expect error because reading idVendor is critical + }, + { + name: "Device with missing serial file", + setupMock: func(mfs *mockFileSystemReader) { + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-VID0403_PID6001_NOSERIAL-if00-port0", mode: fs.ModeSymlink}) + byIDSymlinkPath := filepath.Join(byIDPath, "usb-VID0403_PID6001_NOSERIAL-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idVendor"), "0403") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "6001") + mfs.setReadFileError(filepath.Join(usbDeviceSysfsDir, "serial"), os.ErrNotExist) // Serial file does not exist + }, + expected: []SerialDeviceInfo{ + {Vid: "0403", Pid: "6001", SerialNumber: "", Port: filepath.Join(byIDPath, "usb-VID0403_PID6001_NOSERIAL-if00-port0")}, + }, + }, + { + name: "VID/PID filter case insensitivity", + vidFilter: "0a1b", pidFilter: "0c2d", // Filter with lowercase + setupMock: func(mfs *mockFileSystemReader) { + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-VID0A1B_PID0C2D_SERIALXYZ-if00-port0", mode: fs.ModeSymlink}) + byIDSymlinkPath := filepath.Join(byIDPath, "usb-VID0A1B_PID0C2D_SERIALXYZ-if00-port0") + mfs.addSymlink(byIDSymlinkPath, "../../ttyUSB0") + sysTTYDeviceLink := "/sys/class/tty/ttyUSB0/device" + usbDeviceSysfsDir := "/sys/devices/pci0/usb1/1-1" + mfs.addSymlink(sysTTYDeviceLink, filepath.Join(usbDeviceSysfsDir, "1-1:1.0")) + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idVendor"), "0A1B") // Device has uppercase + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "idProduct"), "0C2D") + mfs.addFile(filepath.Join(usbDeviceSysfsDir, "serial"), "SERIALXYZ") + }, + expected: []SerialDeviceInfo{ + {Vid: "0A1B", Pid: "0C2D", SerialNumber: "SERIALXYZ", Port: filepath.Join(byIDPath, "usb-VID0A1B_PID0C2D_SERIALXYZ-if00-port0")}, + }, + }, + { + name: "Multiple devices, one matching filter", + vidFilter: "1A86", pidFilter: "7523", + setupMock: func(mfs *mockFileSystemReader) { + // Device 1 (matches) + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-QinHeng_Electronics_CH340_SERIAL_MATCH-if00-port0", mode: fs.ModeSymlink}) + mfs.addSymlink(filepath.Join(byIDPath, "usb-QinHeng_Electronics_CH340_SERIAL_MATCH-if00-port0"), "../../ttyUSB0") + mfs.addSymlink("/sys/class/tty/ttyUSB0/device", "/sys/devices/pci0/usb1/1-1/1-1:1.0") + mfs.addFile("/sys/devices/pci0/usb1/1-1/idVendor", "1A86") + mfs.addFile("/sys/devices/pci0/usb1/1-1/idProduct", "7523") + mfs.addFile("/sys/devices/pci0/usb1/1-1/serial", "SERIAL_MATCH") + + // Device 2 (does not match) + mfs.addDirEntry(byIDPath, &mockDirEntry{name: "usb-FTDI_FT232R_USB_UART_SERIAL_NOMATCH-if00-port0", mode: fs.ModeSymlink}) + mfs.addSymlink(filepath.Join(byIDPath, "usb-FTDI_FT232R_USB_UART_SERIAL_NOMATCH-if00-port0"), "../../ttyUSB1") + mfs.addSymlink("/sys/class/tty/ttyUSB1/device", "/sys/devices/pci0/usb1/1-2/1-2:1.0") + mfs.addFile("/sys/devices/pci0/usb1/1-2/idVendor", "0403") + mfs.addFile("/sys/devices/pci0/usb1/1-2/idProduct", "6001") + mfs.addFile("/sys/devices/pci0/usb1/1-2/serial", "SERIAL_NOMATCH") + }, + expected: []SerialDeviceInfo{ + {Vid: "1A86", Pid: "7523", SerialNumber: "SERIAL_MATCH", Port: filepath.Join(byIDPath, "usb-QinHeng_Electronics_CH340_SERIAL_MATCH-if00-port0")}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mfs := newMockFileSystemReader() + tt.setupMock(mfs) + + devices, err := getSerialDevicesWithReader(tt.vidFilter, tt.pidFilter, mfs) + + if (err != nil) != tt.wantErr { + t.Fatalf("getSerialDevicesWithReader() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(devices, tt.expected) { + // For easier debugging of slice differences + expectedStr := fmt.Sprintf("%+v", tt.expected) + gotStr := fmt.Sprintf("%+v", devices) + // Compare string representations for easier visual diff if DeepEqual fails + if expectedStr != gotStr { + t.Errorf("getSerialDevicesWithReader() mismatch:\nExpected: %s\nGot: %s\n--- Raw Expected ---\n%+v\n--- Raw Got ---\n%+v", expectedStr, gotStr, tt.expected, devices) + } else { // If string representations match but DeepEqual failed, it might be due to nil vs empty slice + if len(tt.expected) == 0 && len(devices) == 0 { + // This is fine, both are effectively "no devices" + } else { + t.Errorf("getSerialDevicesWithReader() DeepEqual failed. Expected: %+v, Got: %+v", tt.expected, devices) + } + } + } else if !tt.wantErr && tt.expected == nil && len(devices) == 0 { + // Special case: if tt.expected is nil and devices is empty slice, it's a match + // This handles the case where an empty slice is expected, and reflect.DeepEqual(nil, []T{}) is false + } else if !tt.wantErr && len(tt.expected) == 0 && devices == nil { + // Special case: if tt.expected is empty slice and devices is nil, it's a match + } + }) + } +} + +func TestFindSerialDeviceInfoDirWithReader(t *testing.T) { + t.Helper() + // Base path for tty devices, e.g., /dev/ttyUSB0 + const ttyDevicePath = "/dev/ttyUSB0" + // Path to the 'device' symlink in sysfs for this tty device + sysTTYDeviceLink := filepath.Join("/sys/class/tty", filepath.Base(ttyDevicePath), "device") + + tests := []struct { + name string + setupMock func(*mockFileSystemReader) + expected string // Expected path to the USB device directory, or "" if not found + }{ + { + name: "Found in current dir (pointed by 'device' symlink)", + setupMock: func(mfs *mockFileSystemReader) { + // /sys/class/tty/ttyUSB0/device -> /sys/devices/pci0/usb1/1-1/1-1:1.0 + mfs.addSymlink(sysTTYDeviceLink, "/sys/devices/pci0/usb1/1-1/1-1:1.0") + // Mock idVendor/idProduct directly in /sys/devices/pci0/usb1/1-1/1-1:1.0 + mfs.mockStats[filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idVendor")] = &mockFileInfo{} + mfs.mockStats[filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idProduct")] = &mockFileInfo{} + }, + expected: "/sys/devices/pci0/usb1/1-1/1-1:1.0", + }, + { + name: "Found in parent dir", + setupMock: func(mfs *mockFileSystemReader) { + // /sys/class/tty/ttyUSB0/device -> /sys/devices/pci0/usb1/1-1/1-1:1.0 + mfs.addSymlink(sysTTYDeviceLink, "/sys/devices/pci0/usb1/1-1/1-1:1.0") + // idVendor/idProduct not in 1-1:1.0, but in 1-1 + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idVendor"), os.ErrNotExist) // So it fails check current + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idProduct"), os.ErrNotExist) + + mfs.mockStats[filepath.Join("/sys/devices/pci0/usb1/1-1", "idVendor")] = &mockFileInfo{} + mfs.mockStats[filepath.Join("/sys/devices/pci0/usb1/1-1", "idProduct")] = &mockFileInfo{} + }, + expected: "/sys/devices/pci0/usb1/1-1", + }, + { + name: "Found in grandparent dir", + setupMock: func(mfs *mockFileSystemReader) { + // /sys/class/tty/ttyUSB0/device -> /sys/devices/pci0/usb1/1-1/1-1:1.0/tty/ttyUSB0 + mfs.addSymlink(sysTTYDeviceLink, "/sys/devices/pci0/usb1/1-1/1-1:1.0/tty/ttyUSB0") + // idVendor/idProduct not in .../ttyUSB0 or .../1-1:1.0, but in .../1-1 + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0/tty/ttyUSB0", "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0/tty/ttyUSB0", "idProduct"), os.ErrNotExist) + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idProduct"), os.ErrNotExist) + + mfs.mockStats[filepath.Join("/sys/devices/pci0/usb1/1-1", "idVendor")] = &mockFileInfo{} + mfs.mockStats[filepath.Join("/sys/devices/pci0/usb1/1-1", "idProduct")] = &mockFileInfo{} + }, + expected: "/sys/devices/pci0/usb1/1-1", + }, + { + name: "Not found - VID/PID files do not exist in hierarchy", + setupMock: func(mfs *mockFileSystemReader) { + mfs.addSymlink(sysTTYDeviceLink, "/sys/devices/pci0/usb1/1-1/1-1:1.0/tty/ttyUSB0") + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0/tty/ttyUSB0", "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1/1-1:1.0", "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join("/sys/devices/pci0/usb1/1-1", "idVendor"), os.ErrNotExist) + // No need to mock idProduct if idVendor is already not found for all relevant paths + }, + expected: "", + }, + { + name: "EvalSymlinks error for sysTTYDeviceLink", + setupMock: func(mfs *mockFileSystemReader) { + mfs.setEvalSymlinksError(sysTTYDeviceLink, errors.New("eval symlink failed")) + }, + expected: "", + }, + { + name: "Pathological grandparent (avoid going to . or /)", + setupMock: func(mfs *mockFileSystemReader) { + // Sys tty path /sys/class/tty/ttyS0/device -> /sys/devices/platform/serial8250/tty/ttyS0 + // This structure might mean idVendor/idProduct are not found in typical USB-like parent/grandparent. + localSysTTYDeviceLink := "/sys/class/tty/ttyS0/device" + targetPath := "/sys/devices/platform/serial8250/tty/ttyS0" // No "usb" like paths here + mfs.addSymlink(localSysTTYDeviceLink, targetPath) + + // Assume idVendor/idProduct are not found anywhere up this path. + mfs.setStatError(filepath.Join(targetPath, "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join(filepath.Dir(targetPath), "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join(filepath.Dir(filepath.Dir(targetPath)), "idVendor"), os.ErrNotExist) + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mfs := newMockFileSystemReader() + tt.setupMock(mfs) + + // For the pathological grandparent test, use a different ttyDevicePath + currentTTYDevicePath := ttyDevicePath + if tt.name == "Pathological grandparent (avoid going to . or /)" { + currentTTYDevicePath = "/dev/ttyS0" + } + + got := findSerialDeviceInfoDirWithReader(currentTTYDevicePath, mfs) + if got != tt.expected { + t.Errorf("findSerialDeviceInfoDirWithReader() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestCheckForVIDPIDFilesWithReader(t *testing.T) { + t.Helper() + tests := []struct { + name string + setupMock func(*mockFileSystemReader) + dirPath string + expected bool + }{ + { + name: "VID and PID exist", + setupMock: func(mfs *mockFileSystemReader) { + mfs.mockStats[filepath.Join("/sys/test_device", "idVendor")] = &mockFileInfo{name: "idVendor"} + mfs.mockStats[filepath.Join("/sys/test_device", "idProduct")] = &mockFileInfo{name: "idProduct"} + }, + dirPath: "/sys/test_device", + expected: true, + }, + { + name: "idVendor missing", + setupMock: func(mfs *mockFileSystemReader) { + mfs.setStatError(filepath.Join("/sys/test_device", "idVendor"), os.ErrNotExist) + mfs.mockStats[filepath.Join("/sys/test_device", "idProduct")] = &mockFileInfo{name: "idProduct"} + }, + dirPath: "/sys/test_device", + expected: false, + }, + { + name: "idProduct missing", + setupMock: func(mfs *mockFileSystemReader) { + mfs.mockStats[filepath.Join("/sys/test_device", "idVendor")] = &mockFileInfo{name: "idVendor"} + mfs.setStatError(filepath.Join("/sys/test_device", "idProduct"), os.ErrNotExist) + }, + dirPath: "/sys/test_device", + expected: false, + }, + { + name: "Both idVendor and idProduct missing", + setupMock: func(mfs *mockFileSystemReader) { + mfs.setStatError(filepath.Join("/sys/test_device", "idVendor"), os.ErrNotExist) + mfs.setStatError(filepath.Join("/sys/test_device", "idProduct"), os.ErrNotExist) + }, + dirPath: "/sys/test_device", + expected: false, + }, + { + name: "Stat returns other error for idVendor", + setupMock: func(mfs *mockFileSystemReader) { + mfs.setStatError(filepath.Join("/sys/test_device", "idVendor"), errors.New("some stat error")) + mfs.mockStats[filepath.Join("/sys/test_device", "idProduct")] = &mockFileInfo{name: "idProduct"} + }, + dirPath: "/sys/test_device", + expected: false, + }, + { + name: "Stat returns other error for idProduct", + setupMock: func(mfs *mockFileSystemReader) { + mfs.mockStats[filepath.Join("/sys/test_device", "idVendor")] = &mockFileInfo{name: "idVendor"} + mfs.setStatError(filepath.Join("/sys/test_device", "idProduct"), errors.New("some stat error")) + }, + dirPath: "/sys/test_device", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mfs := newMockFileSystemReader() + if tt.setupMock != nil { + tt.setupMock(mfs) + } + got := checkForVIDPIDFilesWithReader(tt.dirPath, mfs) + if got != tt.expected { + t.Errorf("checkForVIDPIDFilesWithReader() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/serialfinder_windows.go b/serialfinder_windows.go index 40d8a1d..dd63c62 100644 --- a/serialfinder_windows.go +++ b/serialfinder_windows.go @@ -5,104 +5,238 @@ package serialfinder import ( "fmt" + "regexp" "strings" "syscall" "golang.org/x/sys/windows/registry" ) -// GetSerialDevices retrieves USB devices on Windows, filtering by VID and PID, and finds the corresponding COM port -func GetSerialDevices(vid, pid string) ([]SerialDeviceInfo, error) { - var devices []SerialDeviceInfo +// registryKey is an interface wrapper for registry.Key methods used. +type registryKey interface { + ReadSubKeyNames(n int) ([]string, error) + GetStringValue(name string) (string, uint32, error) + Close() error +} + +// defaultRegistryKey wraps a real registry.Key to satisfy the registryKey interface. +type defaultRegistryKey struct { + registry.Key +} + +func (drk *defaultRegistryKey) ReadSubKeyNames(n int) ([]string, error) { + return drk.Key.ReadSubKeyNames(n) +} + +func (drk *defaultRegistryKey) GetStringValue(name string) (string, uint32, error) { + return drk.Key.GetStringValue(name) +} + +func (drk *defaultRegistryKey) Close() error { + return drk.Key.Close() +} + +// registryHandler abstracts registry opening operations. +type registryHandler interface { + OpenKey(base registry.Key, path string, access uint32) (registryKey, error) +} + +// defaultRegistryHandler is the default implementation using the actual registry. +type defaultRegistryHandler struct{} - // Open the registry key for USB devices - key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Enum\USB`, registry.READ) +func (drh *defaultRegistryHandler) OpenKey(base registry.Key, path string, access uint32) (registryKey, error) { + k, err := registry.OpenKey(base, path, access) if err != nil { return nil, err } + return &defaultRegistryKey{Key: k}, nil +} + +// portCheckerFunc defines the signature for functions that check if a COM port is active. +type portCheckerFunc func(portName string) bool + +// checkPortActive is a variable holding the current port checking function. +// This allows it to be replaced during testing. +var checkPortActive = checkCOMPortActiveWindows + +var ( + vidRegex = regexp.MustCompile(`VID_([0-9a-fA-F]{4})`) + pidRegex = regexp.MustCompile(`PID_([0-9a-fA-F]{4})`) +) + +// GetSerialDevices is the public function to retrieve USB devices on Windows. +// It uses the default registry handler and port checker. +func GetSerialDevices(vidFilter, pidFilter string) ([]SerialDeviceInfo, error) { + return getSerialDevicesWithRegistry(vidFilter, pidFilter, &defaultRegistryHandler{}, checkPortActive) +} + +// getSerialDevicesWithRegistry is the internal implementation allowing for custom registry handling and port checking. +func getSerialDevicesWithRegistry(vidFilter, pidFilter string, rh registryHandler, portCheck portCheckerFunc) ([]SerialDeviceInfo, error) { + var devices []SerialDeviceInfo + + targetVidUpper := strings.ToUpper(vidFilter) + targetPidUpper := strings.ToUpper(pidFilter) + + // The baseKey is effectively registry.LOCAL_MACHINE, but OpenKey in registryHandler takes registry.Key + // So we need to open the initial Enum\USB key here before passing it to the loop that uses rh.OpenKey for subkeys. + // This is a bit awkward. A cleaner way might be for registryHandler.OpenKey to handle predefined keys + // or for the first key to be opened outside and then its subkeys opened via rh.OpenKey. + // For now, let's open the EnumUSB key directly and then use rh for its children. + // This means the mock for rh.OpenKey will operate on sub-paths of Enum\USB. + + enumUSBPath := `SYSTEM\CurrentControlSet\Enum\USB` + enumUSBKeyHandle, err := registry.OpenKey(registry.LOCAL_MACHINE, enumUSBPath, registry.READ) + if err != nil { + return nil, fmt.Errorf("failed to open USB enumeration registry key LKM\\%s: %w", enumUSBPath, err) + } + // Wrap the initially opened key so its methods (ReadSubKeyNames, Close) are called on the real key. + // The registryKey interface is primarily for keys *returned by* rh.OpenKey. + // This is still a bit mixed. Let's assume rh.OpenKey can handle opening the first key too. + // To do this, rh.OpenKey needs to accept nil or a specific marker for LOCAL_MACHINE. + // Or, the path passed to rh.OpenKey includes the top-level (e.g. "LKM\\SYSTEM\\..."). + // Let's refine registryHandler to make OpenKey more flexible or add a method for base key. + // For this iteration, we'll assume rh.OpenKey is for subkeys OF an already opened key. + // So, the `key` variable below will be the real `registry.Key` for `Enum\USB`. + + // Re-evaluating: The `registryHandler`'s `OpenKey` takes `base registry.Key`. + // So, for the first call, `base` is `registry.LOCAL_MACHINE`. + // For subsequent calls, `base` is the key returned by the previous `OpenKey` call (wrapped). + // This means `defaultRegistryKey` needs to expose its underlying `registry.Key` or + // `registryHandler.OpenKey` needs to accept `registryKey` as base. + // Let's make `registryKey` expose its underlying `registry.Key` if it's a `defaultRegistryKey`. + + // Simpler: Let registryHandler.OpenKey take the full path from a known root if base is nil, + // or path relative to base if base is not nil. + // For now, the interface is `OpenKey(base registry.Key, path string, access uint32)`. + // So, the first call: rh.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Enum\USB`, ...) + // Subsequent calls: rh.OpenKey(parentKey.(actual_type).Key, subPath, ...) -> this is messy. + + // Cleanest approach for interface: + // registryHandler.OpenTopLevelKey(path string, access uint32) (registryKey, error) + // registryKey.OpenSubKey(path string, access uint32) (registryKey, error) - this is better. + // + // Sticking to current plan for now: + // Top-level key: + key, err := rh.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Enum\USB`, registry.READ) + if err != nil { + return nil, fmt.Errorf("failed to open USB enumeration registry key: %w", err) + } defer key.Close() - // Read the list of subkeys (device IDs) - deviceIDs, err := key.ReadSubKeyNames(-1) + deviceInstanceIDs, err := key.ReadSubKeyNames(-1) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read USB device instance IDs: %w", err) } - // Iterate over each device ID - for _, deviceID := range deviceIDs { - // Check if the deviceID contains the specified VID and PID - if strings.Contains(deviceID, fmt.Sprintf("VID_%s&PID_%s", vid, pid)) { - deviceKey, err := registry.OpenKey(key, deviceID, registry.READ) - if err != nil { - continue - } - defer deviceKey.Close() + for _, deviceInstanceID := range deviceInstanceIDs { + vidMatches := vidRegex.FindStringSubmatch(deviceInstanceID) + var actualVid string + if len(vidMatches) > 1 { + actualVid = strings.ToUpper(vidMatches[1]) + } else { + continue + } - // Read the list of subkeys under each device ID (which usually include serial numbers) - serials, err := deviceKey.ReadSubKeyNames(-1) + pidMatches := pidRegex.FindStringSubmatch(deviceInstanceID) + var actualPid string + if len(pidMatches) > 1 { + actualPid = strings.ToUpper(pidMatches[1]) + } else { + continue + } + + if targetVidUpper != "" && actualVid != targetVidUpper { + continue + } + if targetPidUpper != "" && actualPid != targetPidUpper { + continue + } + + // Open the specific device instance key. Base is 'key' (Enum\USB). + // This assumes 'key' obtained from rh.OpenKey can be used as a 'base' for another rh.OpenKey. + // This implies that the registryKey interface needs to be usable as a registry.Key for the base argument. + // This is where the design gets tricky. The `base` in `registry.OpenKey` is a concrete `registry.Key`. + // The `registryKey` interface would hide this. + // A simple fix: defaultRegistryKey holds registry.Key, and we cast if needed by defaultRegistryHandler. + // Or, the handler is responsible for all openings. + // Let's pass the *path* to the device instance to iterateSerialsWindowsWithRegistry, + // and it will use rh.OpenKey(registry.LOCAL_MACHINE, fullPathToInstance, ...) + + // Path to device instance: SYSTEM\CurrentControlSet\Enum\USB\ + fullDeviceInstancePath := fmt.Sprintf(`SYSTEM\CurrentControlSet\Enum\USB\%s`, deviceInstanceID) + + // The subkeys (serial numbers) are read from the deviceInstanceKey itself. + // So, we need to open deviceInstanceKey first. + deviceInstanceRegKey, err := rh.OpenKey(registry.LOCAL_MACHINE, fullDeviceInstancePath, registry.READ) + if err != nil { + continue + } + // Defer needs to be inside the loop for keys opened in loop + func() { + defer deviceInstanceRegKey.Close() + instanceSubKeyNames, err := deviceInstanceRegKey.ReadSubKeyNames(-1) if err != nil { - continue + return // continue outer loop } - // Iterate over each serial number - for _, serial := range serials { - device := iterateSerialsWindows(serial, deviceID, key) - if device != (SerialDeviceInfo{}) { // Append only if the device is active + for _, instanceSubKeyName := range instanceSubKeyNames { + // Path to "Device Parameters" key: SYSTEM\CurrentControlSet\Enum\USB\\\Device Parameters + deviceParamsPath := fmt.Sprintf(`%s\%s\Device Parameters`, fullDeviceInstancePath, instanceSubKeyName) + + device := iterateSerialsWindowsWithRegistry( + instanceSubKeyName, deviceInstanceID, actualVid, actualPid, + deviceParamsPath, rh, portCheck, + ) + if device.Port != "" { devices = append(devices, device) } } - } + }() // Anonymous function for defer scoping } - return devices, nil } -// Helper function to iterate over serials and get the corresponding COM ports on Windows. -func iterateSerialsWindows(serial, deviceID string, key registry.Key) SerialDeviceInfo { - // Open the `Device Parameters` key to find the COM port - deviceParamsKeyPath := fmt.Sprintf(`%s\%s\Device Parameters`, deviceID, serial) - deviceParamsKey, err := registry.OpenKey(key, deviceParamsKeyPath, registry.READ) +// iterateSerialsWindowsWithRegistry is the testable helper function. +// deviceParamsRegistryPath is the full path from LOCAL_MACHINE to the "Device Parameters" key. +func iterateSerialsWindowsWithRegistry( + serialNumber, deviceInstanceID, vid, pid string, + deviceParamsRegistryPath string, + rh registryHandler, portCheck portCheckerFunc, +) SerialDeviceInfo { + + deviceParamsKey, err := rh.OpenKey(registry.LOCAL_MACHINE, deviceParamsRegistryPath, registry.READ) if err != nil { return SerialDeviceInfo{} } defer deviceParamsKey.Close() - // Read the `PortName` value, which should contain the COM port portName, _, err := deviceParamsKey.GetStringValue("PortName") if err != nil { return SerialDeviceInfo{} } - // Check if the COM port can be opened to determine if the device is active - isActive := checkCOMPortActiveWindows(portName) - if !isActive { + if !portCheck(portName) { return SerialDeviceInfo{} } return SerialDeviceInfo{ - SerialNumber: serial, - Vid: strings.Split(deviceID, "&")[0][4:], - Pid: strings.Split(deviceID, "&")[1][4:], + SerialNumber: serialNumber, + Vid: vid, + Pid: pid, Port: portName, } } -// checkCOMPortActiveWindows tries to open the COM port to check if it is active on Windows +// checkCOMPortActiveWindows tries to open the COM port to check if it is active on Windows. func checkCOMPortActiveWindows(portName string) bool { comPort := fmt.Sprintf("\\\\.\\%s", portName) handle, err := syscall.CreateFile( syscall.StringToUTF16Ptr(comPort), syscall.GENERIC_READ|syscall.GENERIC_WRITE, - 0, - nil, - syscall.OPEN_EXISTING, - 0, - 0, - ) + 0, nil, syscall.OPEN_EXISTING, 0, 0) if err != nil { return false } defer syscall.CloseHandle(handle) - return true } diff --git a/serialfinder_windows_test.go b/serialfinder_windows_test.go new file mode 100644 index 0000000..ee6f96b --- /dev/null +++ b/serialfinder_windows_test.go @@ -0,0 +1,428 @@ +//go:build windows +// +build windows + +package serialfinder + +import ( + "errors" + "fmt" + "reflect" + "regexp" // For TestVidPidRegex if not already imported by main file for test file + "strings" + "testing" + + "golang.org/x/sys/windows/registry" // For registry.Key constants like LOCAL_MACHINE +) + +// mockRegistryKey implements the registryKey interface for testing. +type mockRegistryKey struct { + subKeyNamesToReturn []string + subKeyNamesError error + stringValueToReturn string + stringTypeToReturn uint32 + stringValueError error + closeError error + name string // For debugging or identification +} + +func (mrk *mockRegistryKey) ReadSubKeyNames(n int) ([]string, error) { + return mrk.subKeyNamesToReturn, mrk.subKeyNamesError +} + +func (mrk *mockRegistryKey) GetStringValue(name string) (string, uint32, error) { + // Could add logic here to return different strings based on 'name' if needed + return mrk.stringValueToReturn, mrk.stringTypeToReturn, mrk.stringValueError +} + +func (mrk *mockRegistryKey) Close() error { + return mrk.closeError +} + +// mockRegistryHandler implements the registryHandler interface for testing. +type mockRegistryHandler struct { + // mockKeys maps a full path (string) to a mockRegistryKey or an error + mockKeys map[string]*mockRegistryKey + openKeyError map[string]error // Specific error for a path + genericOpenKeyError error // Generic error if path not in openKeyError +} + +func newMockRegistryHandler() *mockRegistryHandler { + return &mockRegistryHandler{ + mockKeys: make(map[string]*mockRegistryKey), + openKeyError: make(map[string]error), + } +} + +func (mrh *mockRegistryHandler) OpenKey(base registry.Key, path string, access uint32) (registryKey, error) { + // In tests, base is usually registry.LOCAL_MACHINE. We'll use the path as the key for mocks. + // A real implementation might need to combine base and path for uniqueness if base varies. + fullPath := path // Assuming path is unique enough for mock map key + // For more complex scenarios, one might create a unique key from base and path. + + if err, exists := mrh.openKeyError[fullPath]; exists { + return nil, err + } + if mrh.genericOpenKeyError != nil { + return nil, mrh.genericOpenKeyError + } + + key, ok := mrh.mockKeys[fullPath] + if !ok { + return nil, fmt.Errorf("mockRegistryHandler: unmocked path %s", fullPath) // Or registry.ErrNotExist + } + return key, nil +} + +// Helper to add a mock key to the handler +func (mrh *mockRegistryHandler) addMockKey(path string, key *mockRegistryKey) { + mrh.mockKeys[path] = key + key.name = path // Store path in key for easier debugging if needed +} + +// Helper to set an error for a specific OpenKey path +func (mrh *mockRegistryHandler) setOpenKeyError(path string, err error) { + mrh.openKeyError[path] = err +} + + +// mockPortChecker is a utility to create a portCheckerFunc for tests. +func mockPortChecker(shouldBeActive bool) portCheckerFunc { + return func(portName string) bool { + return shouldBeActive + } +} + +func TestVidPidRegex(t *testing.T) { + t.Helper() + tests := []struct { + name string + deviceID string + wantVID string + wantPID string + vidShouldMatch bool + pidShouldMatch bool + }{ + { + name: "Standard USB VID/PID", + deviceID: `USB\VID_1A86&PID_7523\CH340SERIAL`, + wantVID: "1A86", + wantPID: "7523", + vidShouldMatch: true, + pidShouldMatch: true, + }, + { + name: "FTDI Bus VID/PID", + deviceID: `FTDIBUS\VID_0403+PID_6001+A50285BI\0000`, + wantVID: "0403", + wantPID: "6001", + vidShouldMatch: true, + pidShouldMatch: true, + }, + { + name: "VID/PID with lowercase hex", + deviceID: `USB\VID_abcd&PID_ef01\SERIAL`, + wantVID: "abcd", + wantPID: "ef01", + vidShouldMatch: true, + pidShouldMatch: true, + }, + { + name: "Only VID present", + deviceID: `USB\VID_1234\NoPID`, + wantVID: "1234", + wantPID: "", + vidShouldMatch: true, + pidShouldMatch: false, + }, + { + name: "Only PID present (malformed but test regex)", + deviceID: `USB\Something&PID_5678\NoVID`, + wantVID: "", + wantPID: "5678", + vidShouldMatch: false, + pidShouldMatch: true, + }, + { + name: "No VID or PID", + deviceID: `USB\SomethingElse\AnotherThing`, + wantVID: "", + wantPID: "", + vidShouldMatch: false, + pidShouldMatch: false, + }, + { + name: "Malformed VID (too short)", + deviceID: `USB\VID_123&PID_5678\Serial`, + wantVID: "", + wantPID: "5678", + vidShouldMatch: false, + pidShouldMatch: true, + }, + { + name: "Malformed PID (non-hex)", + deviceID: `USB\VID_1234&PID_GHIJ\Serial`, + wantVID: "1234", + wantPID: "", + vidShouldMatch: true, + pidShouldMatch: false, + }, + { + name: "Empty Device ID", + deviceID: ``, + wantVID: "", + wantPID: "", + vidShouldMatch: false, + pidShouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test VID + vidMatches := vidRegex.FindStringSubmatch(tt.deviceID) + if tt.vidShouldMatch { + if len(vidMatches) < 2 { + t.Errorf("vidRegex did not find VID, but expected one in %q", tt.deviceID) + } else if vidMatches[1] != tt.wantVID { + t.Errorf("vidRegex got VID %q, want %q from %q", vidMatches[1], tt.wantVID, tt.deviceID) + } + } else { + if len(vidMatches) > 1 { + t.Errorf("vidRegex found VID %q, but expected none in %q", vidMatches[1], tt.deviceID) + } + } + + // Test PID + pidMatches := pidRegex.FindStringSubmatch(tt.deviceID) + if tt.pidShouldMatch { + if len(pidMatches) < 2 { + t.Errorf("pidRegex did not find PID, but expected one in %q", tt.deviceID) + } else if pidMatches[1] != tt.wantPID { + t.Errorf("pidRegex got PID %q, want %q from %q", pidMatches[1], tt.wantPID, tt.deviceID) + } + } else { + if len(pidMatches) > 1 { + t.Errorf("pidRegex found PID %q, but expected none in %q", pidMatches[1], tt.deviceID) + } + } + }) + } +} + +func TestGetSerialDevicesWithRegistry(t *testing.T) { + t.Helper() + const enumUSBPath = `SYSTEM\CurrentControlSet\Enum\USB` + + tests := []struct { + name string + vidFilter string + pidFilter string + setupMock func(*mockRegistryHandler) + portChecker portCheckerFunc + expected []SerialDeviceInfo + wantErr bool + }{ + { + name: "OpenKey for EnumUSB fails", + setupMock: func(mrh *mockRegistryHandler) { + mrh.setOpenKeyError(enumUSBPath, errors.New("failed to open Enum\\USB")) + }, + portChecker: mockPortChecker(true), + wantErr: true, + }, + { + name: "EnumUSB ReadSubKeyNames fails", + setupMock: func(mrh *mockRegistryHandler) { + enumUSBKey := &mockRegistryKey{subKeyNamesError: errors.New("failed to read subkeys")} + mrh.addMockKey(enumUSBPath, enumUSBKey) + }, + portChecker: mockPortChecker(true), + wantErr: true, + }, + { + name: "No device instance IDs", + setupMock: func(mrh *mockRegistryHandler) { + enumUSBKey := &mockRegistryKey{subKeyNamesToReturn: []string{}} + mrh.addMockKey(enumUSBPath, enumUSBKey) + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{}, + }, + { + name: "Single device, no filter, port active", + vidFilter: "", pidFilter: "", + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0403&PID_6001" + instancePath := enumUSBPath + `\` + deviceInstanceID + serialKeyName := "SERIAL123" + deviceParamsPath := instancePath + `\` + serialKeyName + `\Device Parameters` + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + mrh.addMockKey(instancePath, &mockRegistryKey{subKeyNamesToReturn: []string{serialKeyName}}) + mrh.addMockKey(deviceParamsPath, &mockRegistryKey{stringValueToReturn: "COM3"}) + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{ + {Vid: "0403", Pid: "6001", SerialNumber: "SERIAL123", Port: "COM3"}, + }, + }, + { + name: "Single device, matches VID/PID filter, port active", + vidFilter: "0403", pidFilter: "6001", + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0403&PID_6001" + instancePath := enumUSBPath + `\` + deviceInstanceID + serialKeyName := "SERIAL123" + deviceParamsPath := instancePath + `\` + serialKeyName + `\Device Parameters` + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + mrh.addMockKey(instancePath, &mockRegistryKey{subKeyNamesToReturn: []string{serialKeyName}}) + mrh.addMockKey(deviceParamsPath, &mockRegistryKey{stringValueToReturn: "COM3"}) + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{ + {Vid: "0403", Pid: "6001", SerialNumber: "SERIAL123", Port: "COM3"}, + }, + }, + { + name: "Single device, VID filter mismatch", + vidFilter: "FFFF", pidFilter: "6001", + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0403&PID_6001" + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + // No need to mock further as it won't be reached + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{}, + }, + { + name: "Single device, port inactive", + vidFilter: "0403", pidFilter: "6001", + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0403&PID_6001" + instancePath := enumUSBPath + `\` + deviceInstanceID + serialKeyName := "SERIAL123" + deviceParamsPath := instancePath + `\` + serialKeyName + `\Device Parameters` + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + mrh.addMockKey(instancePath, &mockRegistryKey{subKeyNamesToReturn: []string{serialKeyName}}) + mrh.addMockKey(deviceParamsPath, &mockRegistryKey{stringValueToReturn: "COM3"}) + }, + portChecker: mockPortChecker(false), // Port is not active + expected: []SerialDeviceInfo{}, + }, + { + name: "Device missing PortName string value", + vidFilter: "0403", pidFilter: "6001", + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0403&PID_6001" + instancePath := enumUSBPath + `\` + deviceInstanceID + serialKeyName := "SERIAL123" + deviceParamsPath := instancePath + `\` + serialKeyName + `\Device Parameters` + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + mrh.addMockKey(instancePath, &mockRegistryKey{subKeyNamesToReturn: []string{serialKeyName}}) + mrh.addMockKey(deviceParamsPath, &mockRegistryKey{stringValueError: errors.New("value not found")}) + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{}, + }, + { + name: "OpenKey error for Device Parameters", + vidFilter: "0403", pidFilter: "6001", + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0403&PID_6001" + instancePath := enumUSBPath + `\` + deviceInstanceID + serialKeyName := "SERIAL123" + deviceParamsPath := instancePath + `\` + serialKeyName + `\Device Parameters` + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + mrh.addMockKey(instancePath, &mockRegistryKey{subKeyNamesToReturn: []string{serialKeyName}}) + mrh.setOpenKeyError(deviceParamsPath, errors.New("cannot open device params")) + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{}, + }, + { + name: "VID/PID filter case insensitivity", + vidFilter: "0a1b", pidFilter: "0c2d", // Filter with lowercase + setupMock: func(mrh *mockRegistryHandler) { + deviceInstanceID := "VID_0A1B&PID_0C2D" // Registry has uppercase + instancePath := enumUSBPath + `\` + deviceInstanceID + serialKeyName := "SERIALXYZ" + deviceParamsPath := instancePath + `\` + serialKeyName + `\Device Parameters` + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{deviceInstanceID}}) + mrh.addMockKey(instancePath, &mockRegistryKey{subKeyNamesToReturn: []string{serialKeyName}}) + mrh.addMockKey(deviceParamsPath, &mockRegistryKey{stringValueToReturn: "COM4"}) + }, + portChecker: mockPortChecker(true), + expected: []SerialDeviceInfo{ + {Vid: "0A1B", Pid: "0C2D", SerialNumber: "SERIALXYZ", Port: "COM4"}, + }, + }, + { + name: "Multiple devices, one active, one inactive, one no portname", + setupMock: func(mrh *mockRegistryHandler) { + // Device 1: Active + dev1ID := "VID_AAAA&PID_1111" + dev1InstancePath := enumUSBPath + `\` + dev1ID + dev1Serial := "SER_ACTIVE" + dev1ParamsPath := dev1InstancePath + `\` + dev1Serial + `\Device Parameters` + mrh.addMockKey(dev1InstancePath, &mockRegistryKey{subKeyNamesToReturn: []string{dev1Serial}}) + mrh.addMockKey(dev1ParamsPath, &mockRegistryKey{stringValueToReturn: "COM10"}) + + // Device 2: Inactive port + dev2ID := "VID_BBBB&PID_2222" + dev2InstancePath := enumUSBPath + `\` + dev2ID + dev2Serial := "SER_INACTIVE" + dev2ParamsPath := dev2InstancePath + `\` + dev2Serial + `\Device Parameters` + mrh.addMockKey(dev2InstancePath, &mockRegistryKey{subKeyNamesToReturn: []string{dev2Serial}}) + mrh.addMockKey(dev2ParamsPath, &mockRegistryKey{stringValueToReturn: "COM11"}) + + // Device 3: No PortName + dev3ID := "VID_CCCC&PID_3333" + dev3InstancePath := enumUSBPath + `\` + dev3ID + dev3Serial := "SER_NOPORT" + dev3ParamsPath := dev3InstancePath + `\` + dev3Serial + `\Device Parameters` + mrh.addMockKey(dev3InstancePath, &mockRegistryKey{subKeyNamesToReturn: []string{dev3Serial}}) + mrh.addMockKey(dev3ParamsPath, &mockRegistryKey{stringValueError: errors.New("no portname")}) + + mrh.addMockKey(enumUSBPath, &mockRegistryKey{subKeyNamesToReturn: []string{dev1ID, dev2ID, dev3ID}}) + }, + portChecker: func(portName string) bool { + return portName == "COM10" // Only COM10 is active + }, + expected: []SerialDeviceInfo{ + {Vid: "AAAA", Pid: "1111", SerialNumber: "SER_ACTIVE", Port: "COM10"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mrh := newMockRegistryHandler() + tt.setupMock(mrh) + + // Store original checkPortActive and defer its restoration + originalCheckPortActive := checkPortActive + checkPortActive = tt.portChecker + defer func() { checkPortActive = originalCheckPortActive }() + + + devices, err := getSerialDevicesWithRegistry(tt.vidFilter, tt.pidFilter, mrh, tt.portChecker) + + if (err != nil) != tt.wantErr { + t.Fatalf("getSerialDevicesWithRegistry() error = %v, wantErr %v", err, tt.wantErr) + } + + if !tt.wantErr { + if len(devices) == 0 && len(tt.expected) == 0 { + // Both are empty, consider it a match. + } else if !reflect.DeepEqual(devices, tt.expected) { + t.Errorf("getSerialDevicesWithRegistry() got = %+v, want %+v", devices, tt.expected) + } + } + }) + } +}