Skip to content

Commit 5c094b6

Browse files
committed
Implement debug register abstraction
Signed-off-by: Ludvig Liljenberg <4257730+ludfjig@users.noreply.github.com>
1 parent 89ce6eb commit 5c094b6

File tree

1 file changed

+269
-0
lines changed

1 file changed

+269
-0
lines changed
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
/*
2+
Copyright 2025 The Hyperlight Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
#[cfg(kvm)]
18+
use kvm_bindings::kvm_debugregs;
19+
#[cfg(mshv3)]
20+
use mshv_bindings::DebugRegisters;
21+
22+
#[derive(Debug, Default, Copy, Clone, PartialEq)]
23+
pub(crate) struct CommonDebugRegs {
24+
pub dr0: u64,
25+
pub dr1: u64,
26+
pub dr2: u64,
27+
pub dr3: u64,
28+
pub dr6: u64,
29+
pub dr7: u64,
30+
}
31+
32+
#[cfg(kvm)]
33+
impl From<kvm_debugregs> for CommonDebugRegs {
34+
fn from(kvm_regs: kvm_debugregs) -> Self {
35+
Self {
36+
dr0: kvm_regs.db[0],
37+
dr1: kvm_regs.db[1],
38+
dr2: kvm_regs.db[2],
39+
dr3: kvm_regs.db[3],
40+
dr6: kvm_regs.dr6,
41+
dr7: kvm_regs.dr7,
42+
}
43+
}
44+
}
45+
#[cfg(kvm)]
46+
impl From<&CommonDebugRegs> for kvm_debugregs {
47+
fn from(common_regs: &CommonDebugRegs) -> Self {
48+
kvm_debugregs {
49+
db: [
50+
common_regs.dr0,
51+
common_regs.dr1,
52+
common_regs.dr2,
53+
common_regs.dr3,
54+
],
55+
dr6: common_regs.dr6,
56+
dr7: common_regs.dr7,
57+
..Default::default()
58+
}
59+
}
60+
}
61+
#[cfg(mshv3)]
62+
impl From<DebugRegisters> for CommonDebugRegs {
63+
fn from(mshv_regs: DebugRegisters) -> Self {
64+
Self {
65+
dr0: mshv_regs.dr0,
66+
dr1: mshv_regs.dr1,
67+
dr2: mshv_regs.dr2,
68+
dr3: mshv_regs.dr3,
69+
dr6: mshv_regs.dr6,
70+
dr7: mshv_regs.dr7,
71+
}
72+
}
73+
}
74+
#[cfg(mshv3)]
75+
impl From<&CommonDebugRegs> for DebugRegisters {
76+
fn from(common_regs: &CommonDebugRegs) -> Self {
77+
DebugRegisters {
78+
dr0: common_regs.dr0,
79+
dr1: common_regs.dr1,
80+
dr2: common_regs.dr2,
81+
dr3: common_regs.dr3,
82+
dr6: common_regs.dr6,
83+
dr7: common_regs.dr7,
84+
}
85+
}
86+
}
87+
88+
#[cfg(target_os = "windows")]
89+
use windows::Win32::System::Hypervisor::*;
90+
91+
#[cfg(target_os = "windows")]
92+
impl From<&CommonDebugRegs>
93+
for [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_DEBUG_REGS_NAMES_LEN]
94+
{
95+
fn from(regs: &CommonDebugRegs) -> Self {
96+
[
97+
(
98+
WHvX64RegisterDr0,
99+
Align16(WHV_REGISTER_VALUE { Reg64: regs.dr0 }),
100+
),
101+
(
102+
WHvX64RegisterDr1,
103+
Align16(WHV_REGISTER_VALUE { Reg64: regs.dr1 }),
104+
),
105+
(
106+
WHvX64RegisterDr2,
107+
Align16(WHV_REGISTER_VALUE { Reg64: regs.dr2 }),
108+
),
109+
(
110+
WHvX64RegisterDr3,
111+
Align16(WHV_REGISTER_VALUE { Reg64: regs.dr3 }),
112+
),
113+
(
114+
WHvX64RegisterDr6,
115+
Align16(WHV_REGISTER_VALUE { Reg64: regs.dr6 }),
116+
),
117+
(
118+
WHvX64RegisterDr7,
119+
Align16(WHV_REGISTER_VALUE { Reg64: regs.dr7 }),
120+
),
121+
]
122+
}
123+
}
124+
125+
#[cfg(target_os = "windows")]
126+
use std::collections::HashSet;
127+
128+
#[cfg(target_os = "windows")]
129+
use super::{Align16, FromWhpRegisterError};
130+
131+
#[cfg(target_os = "windows")]
132+
pub(crate) const WHP_DEBUG_REGS_NAMES_LEN: usize = 6;
133+
#[cfg(target_os = "windows")]
134+
pub(crate) const WHP_DEBUG_REGS_NAMES: [WHV_REGISTER_NAME; WHP_DEBUG_REGS_NAMES_LEN] = [
135+
WHvX64RegisterDr0,
136+
WHvX64RegisterDr1,
137+
WHvX64RegisterDr2,
138+
WHvX64RegisterDr3,
139+
WHvX64RegisterDr6,
140+
WHvX64RegisterDr7,
141+
];
142+
143+
#[cfg(target_os = "windows")]
144+
impl TryFrom<&[(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>)]> for CommonDebugRegs {
145+
type Error = FromWhpRegisterError;
146+
147+
#[expect(
148+
non_upper_case_globals,
149+
reason = "Windows API has lowercase register names"
150+
)]
151+
fn try_from(
152+
regs: &[(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>)],
153+
) -> Result<Self, Self::Error> {
154+
if regs.len() != WHP_DEBUG_REGS_NAMES_LEN {
155+
return Err(FromWhpRegisterError::InvalidLength(regs.len()));
156+
}
157+
let mut registers = CommonDebugRegs::default();
158+
let mut seen_registers = HashSet::new();
159+
160+
for &(name, value) in regs {
161+
let name_id = name.0;
162+
163+
// Check for duplicates
164+
if !seen_registers.insert(name_id) {
165+
return Err(FromWhpRegisterError::DuplicateRegister(name_id));
166+
}
167+
168+
unsafe {
169+
match name {
170+
WHvX64RegisterDr0 => registers.dr0 = value.0.Reg64,
171+
WHvX64RegisterDr1 => registers.dr1 = value.0.Reg64,
172+
WHvX64RegisterDr2 => registers.dr2 = value.0.Reg64,
173+
WHvX64RegisterDr3 => registers.dr3 = value.0.Reg64,
174+
WHvX64RegisterDr6 => registers.dr6 = value.0.Reg64,
175+
WHvX64RegisterDr7 => registers.dr7 = value.0.Reg64,
176+
_ => {
177+
// Given unexpected register
178+
return Err(FromWhpRegisterError::InvalidRegister(name_id));
179+
}
180+
}
181+
}
182+
}
183+
184+
// Set of all expected register names
185+
let expected_registers: HashSet<i32> = WHP_DEBUG_REGS_NAMES
186+
.map(|name| name.0)
187+
.into_iter()
188+
.collect();
189+
190+
// Technically it should not be possible to have any missing registers at this point
191+
// since we are guaranteed to have WHP_DEBUG_REGS_NAMES_LEN (6) non-duplicate registers that have passed the match-arm above, but leaving this here for safety anyway
192+
let missing: HashSet<_> = expected_registers
193+
.difference(&seen_registers)
194+
.cloned()
195+
.collect();
196+
197+
if !missing.is_empty() {
198+
return Err(FromWhpRegisterError::MissingRegister(missing));
199+
}
200+
201+
Ok(registers)
202+
}
203+
}
204+
205+
#[cfg(test)]
206+
mod tests {
207+
use super::*;
208+
209+
fn common_debug_regs() -> CommonDebugRegs {
210+
CommonDebugRegs {
211+
dr0: 1,
212+
dr1: 2,
213+
dr2: 3,
214+
dr3: 4,
215+
dr6: 5,
216+
dr7: 6,
217+
}
218+
}
219+
220+
#[cfg(kvm)]
221+
#[test]
222+
fn round_trip_kvm_debug_regs() {
223+
let original = common_debug_regs();
224+
let kvm_regs: kvm_debugregs = (&original).into();
225+
let converted: CommonDebugRegs = kvm_regs.into();
226+
assert_eq!(original, converted);
227+
}
228+
229+
#[cfg(mshv3)]
230+
#[test]
231+
fn round_trip_mshv_debug_regs() {
232+
let original = common_debug_regs();
233+
let mshv_regs: DebugRegisters = (&original).into();
234+
let converted: CommonDebugRegs = mshv_regs.into();
235+
assert_eq!(original, converted);
236+
}
237+
238+
#[cfg(target_os = "windows")]
239+
#[test]
240+
fn round_trip_whp_debug_regs() {
241+
let original = common_debug_regs();
242+
let whp_regs: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_DEBUG_REGS_NAMES_LEN] =
243+
(&original).into();
244+
let converted: CommonDebugRegs = whp_regs.as_ref().try_into().unwrap();
245+
assert_eq!(original, converted);
246+
247+
// test for duplicate register error handling
248+
let original = common_debug_regs();
249+
let mut whp_regs: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>);
250+
WHP_DEBUG_REGS_NAMES_LEN] = (&original).into();
251+
whp_regs[0].0 = WHvX64RegisterDr1;
252+
let err = CommonDebugRegs::try_from(whp_regs.as_ref()).unwrap_err();
253+
assert_eq!(
254+
err,
255+
FromWhpRegisterError::DuplicateRegister(WHvX64RegisterDr1.0)
256+
);
257+
258+
// test for passing non-standard register (e.g. CR8)
259+
let original = common_debug_regs();
260+
let mut whp_regs: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>);
261+
WHP_DEBUG_REGS_NAMES_LEN] = (&original).into();
262+
whp_regs[0].0 = WHvX64RegisterCr8;
263+
let err = CommonDebugRegs::try_from(whp_regs.as_ref()).unwrap_err();
264+
assert_eq!(
265+
err,
266+
FromWhpRegisterError::InvalidRegister(WHvX64RegisterCr8.0)
267+
);
268+
}
269+
}

0 commit comments

Comments
 (0)