diff --git a/pyproject.toml b/pyproject.toml index fcd9474fda63..7e193455dbab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,6 +205,24 @@ reportUnusedClass = "none" reportUnnecessaryCast = "none" # mypy already checks this. If it fails for pyright its because mypy requires it reportUnnecessaryContains = "none" +[tool.ty] + +[tool.ty.rules] + +# we catch these via pyright or mypy so ignore here +# deprecated: we trigger deprecated on use of deprecated methods in other deprecated methods. +# unresolved-imports: we have a lot of imports that are only available with various libraries for specific instrument drivers +# unused-ignore-comment: mypy already checks for unused-ignores so it they are unused by ty its because mypy requires them +unresolved-import = "ignore" +deprecated = "ignore" +unused-type-ignore-comment = "ignore" + +[[tool.ty.overrides]] +include = ["src/qcodes/instrument_drivers/Harvard/Decadac.py"] + +[tool.ty.overrides.rules] +unresolved-attribute = "ignore" + [tool.pytest.ini_options] minversion = "7.2" testpaths = "tests" diff --git a/src/qcodes/dataset/dond/do_nd.py b/src/qcodes/dataset/dond/do_nd.py index c3b6a304f178..d40d27a11364 100644 --- a/src/qcodes/dataset/dond/do_nd.py +++ b/src/qcodes/dataset/dond/do_nd.py @@ -23,6 +23,7 @@ _set_write_period, catch_interrupts, ) +from qcodes.dataset.experiment_container import Experiment from qcodes.dataset.measurements import Measurement from qcodes.dataset.threading import ( SequentialParamsCaller, @@ -42,7 +43,7 @@ MultiAxesTupleListWithDataSet, ParamMeasT, ) - from qcodes.dataset.experiment_container import Experiment + LOG = logging.getLogger(__name__) SweepVarType = Any @@ -400,8 +401,12 @@ def _get_experiments( experiments_internal: Sequence[Experiment | None] = [ experiments ] * n_experiments_required - else: + elif not isinstance(experiments, Experiment): experiments_internal = experiments + else: + raise TypeError( + f"Invalid type for experiments got {experiments} of type {type(experiments)}" + ) if len(experiments_internal) != n_experiments_required: raise ValueError( diff --git a/src/qcodes/extensions/infer.py b/src/qcodes/extensions/infer.py index d22b4060fd6c..2e596cec6f5f 100644 --- a/src/qcodes/extensions/infer.py +++ b/src/qcodes/extensions/infer.py @@ -264,8 +264,10 @@ def get_parent_instruments_from_chain_of_type( param_chain = get_parameter_chain(parameter) return tuple( + # cast is required since mypy as of 1.19.1 cannot infer the type narrowing based + # on isinstance checks inside comprehensions [ - cast("TInstrument", param.instrument) + cast("TInstrument", param.instrument) # ty: ignore[redundant-cast] for param in param_chain if isinstance(param.instrument, instrument_type) ] diff --git a/src/qcodes/instrument/channel.py b/src/qcodes/instrument/channel.py index 3b774f1dc134..f6132a0a7364 100644 --- a/src/qcodes/instrument/channel.py +++ b/src/qcodes/instrument/channel.py @@ -198,42 +198,42 @@ def __init__( ) @overload - def __getitem__(self, i: int) -> InstrumentModuleType: ... + def __getitem__(self, index: int) -> InstrumentModuleType: ... @overload - def __getitem__(self: Self, i: slice | tuple[int, ...]) -> Self: ... + def __getitem__(self: Self, index: slice | tuple[int, ...]) -> Self: ... def __getitem__( - self: Self, i: int | slice | tuple[int, ...] + self: Self, index: int | slice | tuple[int, ...] ) -> InstrumentModuleType | Self: """ Return either a single channel, or a new :class:`ChannelTuple` containing only the specified channels Args: - i: Either a single channel index or a slice of channels + index: Either a single channel index or a slice of channels to get """ - if isinstance(i, slice): + if isinstance(index, slice): return type(self)( self._parent, self._name, self._chan_type, - self._channels[i], + self._channels[index], multichan_paramclass=self._paramclass, snapshotable=self._snapshotable, ) - elif isinstance(i, tuple): + elif isinstance(index, tuple): return type(self)( self._parent, self._name, self._chan_type, - [self._channels[j] for j in i], + [self._channels[j] for j in index], multichan_paramclass=self._paramclass, snapshotable=self._snapshotable, ) - return self._channels[i] + return self._channels[index] def __iter__(self) -> Iterator[InstrumentModuleType]: return iter(self._channels) @@ -244,8 +244,8 @@ def __reversed__(self) -> Iterator[InstrumentModuleType]: def __len__(self) -> int: return len(self._channels) - def __contains__(self, item: object) -> bool: - return item in self._channels + def __contains__(self, value: object) -> bool: + return value in self._channels def __repr__(self) -> str: return ( @@ -315,11 +315,9 @@ def name_parts(self) -> list[str]: name_parts.append(self.short_name) return name_parts - # the parameter obj should be called value but that would - # be an incompatible change - def index( # pyright: ignore[reportIncompatibleMethodOverride] + def index( self, - obj: InstrumentModuleType, + value: InstrumentModuleType, start: int = 0, stop: int = sys.maxsize, ) -> int: @@ -327,23 +325,21 @@ def index( # pyright: ignore[reportIncompatibleMethodOverride] Return the index of the given object Args: - obj: The object to find in the channel list. + value: The object to find in the channel list. start: Index to start searching from. stop: Index to stop searching at. """ - return self._channels.index(obj, start, stop) + return self._channels.index(value, start, stop) - def count( # pyright: ignore[reportIncompatibleMethodOverride] - self, obj: InstrumentModuleType - ) -> int: + def count(self, value: InstrumentModuleType) -> int: """Returns number of instances of the given object in the list Args: - obj: The object to find in the ChannelTuple. + value: The object to find in the ChannelTuple. """ - return self._channels.count(obj) + return self._channels.count(value) def get_channels_by_name(self: Self, *names: str) -> Self: """ @@ -717,15 +713,15 @@ def __init__( self._locked = False @overload - def __delitem__(self, key: int) -> None: ... + def __delitem__(self, index: int) -> None: ... @overload - def __delitem__(self, key: slice) -> None: ... + def __delitem__(self, index: slice) -> None: ... - def __delitem__(self, key: int | slice) -> None: + def __delitem__(self, index: int | slice) -> None: if self._locked: raise AttributeError("Cannot delete from a locked channel list") - self._channels.__delitem__(key) + self._channels.__delitem__(index) self._channel_mapping = { channel.short_name: channel for channel in self._channels } @@ -759,27 +755,25 @@ def __setitem__( channel.short_name: channel for channel in self._channels } - def append( # pyright: ignore[reportIncompatibleMethodOverride] - self, obj: InstrumentModuleType - ) -> None: + def append(self, value: InstrumentModuleType) -> None: """ Append a Channel to this list. Requires that the ChannelList is not locked and that the channel is of the same type as the ones in the list. Args: - obj: New channel to add to the list. + value: New channel to add to the list. """ if self._locked: raise AttributeError("Cannot append to a locked channel list") - if not isinstance(obj, self._chan_type): + if not isinstance(value, self._chan_type): raise TypeError( f"All items in a channel list must be of the same " - f"type. Adding {type(obj).__name__} to a " + f"type. Adding {type(value).__name__} to a " f"list of {self._chan_type.__name__}." ) - self._channel_mapping[obj.short_name] = obj - self._channels.append(obj) + self._channel_mapping[value.short_name] = value + self._channels.append(value) def clear(self) -> None: """ @@ -791,63 +785,59 @@ def clear(self) -> None: self._channels.clear() self._channel_mapping.clear() - def remove( # pyright: ignore[reportIncompatibleMethodOverride] - self, obj: InstrumentModuleType - ) -> None: + def remove(self, value: InstrumentModuleType) -> None: """ - Removes obj from ChannelList if not locked. + Removes value from ChannelList if not locked. Args: - obj: Channel to remove from the list. + value: Channel to remove from the list. """ if self._locked: raise AttributeError("Cannot remove from a locked channel list") else: - self._channels.remove(obj) - self._channel_mapping.pop(obj.short_name) + self._channels.remove(value) + self._channel_mapping.pop(value.short_name) - def extend( # pyright: ignore[reportIncompatibleMethodOverride] - self, objects: Iterable[InstrumentModuleType] - ) -> None: + def extend(self, values: Iterable[InstrumentModuleType]) -> None: """ - Insert an iterable of objects into the list of channels. + Insert an iterable of InstrumentModules into the list of channels. Args: - objects: A list of objects to add into the + values: A list of InstrumentModules to add into the :class:`ChannelList`. """ - # objects may be a generator but we need to iterate over it twice + # values may be a generator but we need to iterate over it twice # below so copy it into a tuple just in case. if self._locked: raise AttributeError("Cannot extend a locked channel list") - objects_tuple = tuple(objects) - if not all(isinstance(obj, self._chan_type) for obj in objects_tuple): + values_tuple = tuple(values) + if not all(isinstance(value, self._chan_type) for value in values_tuple): raise TypeError("All items in a channel list must be of the same type.") - self._channels.extend(objects_tuple) - self._channel_mapping.update({obj.short_name: obj for obj in objects_tuple}) + self._channels.extend(values_tuple) + self._channel_mapping.update( + {value.short_name: value for value in values_tuple} + ) - def insert( # pyright: ignore[reportIncompatibleMethodOverride] - self, index: int, obj: InstrumentModuleType - ) -> None: + def insert(self, index: int, value: InstrumentModuleType) -> None: """ Insert an object into the ChannelList at a specific index. Args: index: Index to insert object. - obj: Object of type chan_type to insert. + value: Object of type chan_type to insert. """ if self._locked: raise AttributeError("Cannot insert into a locked channel list") - if not isinstance(obj, self._chan_type): + if not isinstance(value, self._chan_type): raise TypeError( f"All items in a channel list must be of the same " - f"type. Adding {type(obj).__name__} to a list of {self._chan_type.__name__}." + f"type. Adding {type(value).__name__} to a list of {self._chan_type.__name__}." ) - self._channels.insert(index, obj) - self._channel_mapping[obj.short_name] = obj + self._channels.insert(index, value) + self._channel_mapping[value.short_name] = value def get_validator(self) -> ChannelTupleValidator: """ diff --git a/src/qcodes/instrument/mockers/ami430.py b/src/qcodes/instrument/mockers/ami430.py index 329c3cbc616c..c02b1bbc68a0 100644 --- a/src/qcodes/instrument/mockers/ami430.py +++ b/src/qcodes/instrument/mockers/ami430.py @@ -163,7 +163,7 @@ def _handle_messages(self, msg): if callable(handler): # some of the callables in the dict does not take arguments. # ignore that warning for now since this is mock code only - rval = handler(args) # pyright: ignore[reportCallIssue] + rval = handler(args) # pyright: ignore[reportCallIssue] # ty: ignore[ too-many-positional-arguments] else: rval = handler diff --git a/src/qcodes/instrument_drivers/AlazarTech/dll_wrapper.py b/src/qcodes/instrument_drivers/AlazarTech/dll_wrapper.py index 15ef43685039..59a131825da3 100644 --- a/src/qcodes/instrument_drivers/AlazarTech/dll_wrapper.py +++ b/src/qcodes/instrument_drivers/AlazarTech/dll_wrapper.py @@ -62,6 +62,7 @@ def _mark_params_as_updated(*args: Any) -> None: def _check_error_code( return_code: int, func: Callable[..., Any], arguments: tuple[Any, ...] ) -> tuple[Any, ...]: + func_name: str = getattr(func, "__name__", "UnknownFunction") if return_code not in {API_SUCCESS, API_DMA_IN_PROGRESS}: argrepr = repr(arguments) if len(argrepr) > 100: @@ -69,15 +70,15 @@ def _check_error_code( logger.error( f"Alazar API returned code {return_code} from function " - f"{func.__name__} with args {argrepr}" + f"{func_name} with args {argrepr}" ) if return_code not in ERROR_CODES: raise RuntimeError( - f"unknown error {return_code} from function {func.__name__} with args: {argrepr}" + f"unknown error {return_code} from function {func_name} with args: {argrepr}" ) raise RuntimeError( - f"error {return_code}: {ERROR_CODES[ReturnCode(return_code)]} from function {func.__name__} with args: {argrepr}" + f"error {return_code}: {ERROR_CODES[ReturnCode(return_code)]} from function {func_name} with args: {argrepr}" ) return arguments diff --git a/src/qcodes/instrument_drivers/Galil/dmc_41x3.py b/src/qcodes/instrument_drivers/Galil/dmc_41x3.py index d50f443499af..d8b8631a718a 100644 --- a/src/qcodes/instrument_drivers/Galil/dmc_41x3.py +++ b/src/qcodes/instrument_drivers/Galil/dmc_41x3.py @@ -254,7 +254,7 @@ def clear_sequence(self, coord_sys: str) -> None: """ -class GalilDMC4133Motor(InstrumentChannel): +class GalilDMC4133Motor(InstrumentChannel["GalilDMC4133Controller"]): """ Class to control a single motor (independent of possible other motors) """ @@ -458,7 +458,7 @@ def wait_till_motor_motion_complete(self) -> None: while self.is_in_motion(): pass except KeyboardInterrupt: - self.root_instrument.abort() + self.parent.abort() self.off() def error_magnitude(self) -> float: diff --git a/src/qcodes/instrument_drivers/Keithley/Keithley_2000.py b/src/qcodes/instrument_drivers/Keithley/Keithley_2000.py index ca478f5a0c43..3d30cc36b37a 100644 --- a/src/qcodes/instrument_drivers/Keithley/Keithley_2000.py +++ b/src/qcodes/instrument_drivers/Keithley/Keithley_2000.py @@ -209,8 +209,6 @@ def __init__( ) """Parameter amplitude""" - self.add_function("reset", call_cmd="*RST") - if reset: self.reset() @@ -220,6 +218,9 @@ def __init__( self.connect_message() + def reset(self) -> None: + self.write("*RST") + def trigger(self) -> None: if not self.trigger_continuous(): self.write("INIT") diff --git a/src/qcodes/instrument_drivers/Keithley/Keithley_2450.py b/src/qcodes/instrument_drivers/Keithley/Keithley_2450.py index cfe6072a17bd..2ce7877968a5 100644 --- a/src/qcodes/instrument_drivers/Keithley/Keithley_2450.py +++ b/src/qcodes/instrument_drivers/Keithley/Keithley_2450.py @@ -45,7 +45,7 @@ def get_selected(self) -> list[Any] | None: return self._user_selected_data -class Keithley2450Buffer(InstrumentChannel): +class Keithley2450Buffer(InstrumentChannel["Keithley2450"]): """ Treat the reading buffer as a submodule, similar to Sense and Source """ @@ -379,7 +379,7 @@ def _set_user_delay(self, value: float) -> None: self.write(set_cmd) -class Keithley2450Source(InstrumentChannel): +class Keithley2450Source(InstrumentChannel["Keithley2450"]): """ The source module of the Keithley 2450 SMU. diff --git a/src/qcodes/instrument_drivers/Keithley/Keithley_7510.py b/src/qcodes/instrument_drivers/Keithley/Keithley_7510.py index 349f1eb7a6ae..b678f697058c 100644 --- a/src/qcodes/instrument_drivers/Keithley/Keithley_7510.py +++ b/src/qcodes/instrument_drivers/Keithley/Keithley_7510.py @@ -84,7 +84,7 @@ def get_raw(self) -> npt.NDArray: return np.linspace(start, stop, n_points) -class Keithley7510Buffer(InstrumentChannel): +class Keithley7510Buffer(InstrumentChannel["Keithley7510"]): """ Treat the reading buffer as a submodule, similar to Sense. """ @@ -419,7 +419,7 @@ class _FunctionMode(TypedDict): range_vals: Numbers | None -class Keithley7510Sense(InstrumentChannel): +class Keithley7510Sense(InstrumentChannel["Keithley7510"]): function_modes: ClassVar[dict[str, _FunctionMode]] = { "voltage": { "name": '"VOLT:DC"', @@ -455,7 +455,7 @@ class Keithley7510Sense(InstrumentChannel): def __init__( self, - parent: VisaInstrument, + parent: "Keithley7510", name: str, proper_function: str, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -630,7 +630,7 @@ def clear_trace(self, buffer_name: str = "defbuffer1") -> None: self.write(f":TRACe:CLEar '{buffer_name}'") -class Keithley7510DigitizeSense(InstrumentChannel): +class Keithley7510DigitizeSense(InstrumentChannel["Keithley7510"]): """ The Digitize sense module of the Keithley 7510 DMM. """ @@ -649,7 +649,7 @@ class Keithley7510DigitizeSense(InstrumentChannel): }, } - def __init__(self, parent: VisaInstrument, name: str, proper_function: str) -> None: + def __init__(self, parent: "Keithley7510", name: str, proper_function: str) -> None: super().__init__(parent, name) self._proper_function = proper_function diff --git a/src/qcodes/instrument_drivers/Keithley/Keithley_s46.py b/src/qcodes/instrument_drivers/Keithley/Keithley_s46.py index 6ee7238bf131..11b7901dc3a4 100644 --- a/src/qcodes/instrument_drivers/Keithley/Keithley_s46.py +++ b/src/qcodes/instrument_drivers/Keithley/Keithley_s46.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, ClassVar from qcodes.instrument import ( - Instrument, VisaInstrument, VisaInstrumentKWArgs, ) @@ -56,7 +55,7 @@ def release(self, channel_number: int) -> None: self._locked_by = None -class S46Parameter(Parameter): +class S46Parameter(Parameter[ParamRawDataType, "KeithleyS46"]): """ A parameter class for S46 channels. We do not use the QCoDeS InstrumentChannel class because our channel has one state parameter, @@ -73,7 +72,7 @@ class S46Parameter(Parameter): def __init__( self, name: str, - instrument: Instrument | None, + instrument: "KeithleyS46", channel_number: int, lock: KeithleyS46RelayLock, **kwargs: Any, @@ -95,7 +94,6 @@ def __init__( ) from e def _get(self, get_cached: bool) -> str: - assert isinstance(self.instrument, KeithleyS46) closed_channels = self.instrument.closed_channels.get_latest() if not get_cached or closed_channels is None: diff --git a/src/qcodes/instrument_drivers/Keithley/_Keithley_2600.py b/src/qcodes/instrument_drivers/Keithley/_Keithley_2600.py index fcba7c7ae539..f52c9158908c 100644 --- a/src/qcodes/instrument_drivers/Keithley/_Keithley_2600.py +++ b/src/qcodes/instrument_drivers/Keithley/_Keithley_2600.py @@ -11,7 +11,6 @@ import qcodes.validators as vals from qcodes.instrument import ( - Instrument, InstrumentChannel, VisaInstrument, VisaInstrumentKWArgs, @@ -34,13 +33,15 @@ log = logging.getLogger(__name__) -class LuaSweepParameter(ArrayParameter): +class LuaSweepParameter(ArrayParameter[npt.NDArray, "Keithley2600Channel"]): """ Parameter class to hold the data from a deployed Lua script sweep. """ - def __init__(self, name: str, instrument: Instrument, **kwargs: Any) -> None: + def __init__( + self, name: str, instrument: Keithley2600Channel, **kwargs: Any + ) -> None: super().__init__( name=name, shape=(1,), @@ -49,7 +50,13 @@ def __init__(self, name: str, instrument: Instrument, **kwargs: Any) -> None: **kwargs, ) - def prepareSweep(self, start: float, stop: float, steps: int, mode: str) -> None: + def prepareSweep( + self, + start: float, + stop: float, + steps: int, + mode: Literal["IV", "VI", "VIfourprobe"], + ) -> None: """ Builds setpoints and labels @@ -63,31 +70,29 @@ def prepareSweep(self, start: float, stop: float, steps: int, mode: str) -> None """ - if mode not in ["IV", "VI", "VIfourprobe"]: - raise ValueError('mode must be either "VI", "IV" or "VIfourprobe"') - self.shape = (steps,) - if mode == "IV": - self.unit = "A" - self.setpoint_names = ("Voltage",) - self.setpoint_units = ("V",) - self.label = "current" - self._short_name = "iv_sweep" - - if mode == "VI": - self.unit = "V" - self.setpoint_names = ("Current",) - self.setpoint_units = ("A",) - self.label = "voltage" - self._short_name = "vi_sweep" - - if mode == "VIfourprobe": - self.unit = "V" - self.setpoint_names = ("Current",) - self.setpoint_units = ("A",) - self.label = "voltage" - self._short_name = "vi_sweep_four_probe" + match mode: + case "IV": + self.unit = "A" + self.setpoint_names = ("Voltage",) + self.setpoint_units = ("V",) + self.label = "current" + self._short_name = "iv_sweep" + case "VI": + self.unit = "V" + self.setpoint_names = ("Current",) + self.setpoint_units = ("A",) + self.label = "voltage" + self._short_name = "vi_sweep" + case "VIfourprobe": + self.unit = "V" + self.setpoint_names = ("Current",) + self.setpoint_units = ("A",) + self.label = "voltage" + self._short_name = "vi_sweep_four_probe" + case _: + raise ValueError('mode must be either "VI", "IV" or "VIfourprobe"') self.setpoints = (tuple(np.linspace(start, stop, steps)),) @@ -107,7 +112,7 @@ def get_raw(self) -> npt.NDArray: return data -class TimeTrace(ParameterWithSetpoints): +class TimeTrace(ParameterWithSetpoints[npt.NDArray, "Keithley2600Channel"]): """ A parameter class that holds the data corresponding to the time dependence of current and voltage. @@ -198,7 +203,7 @@ def get_raw(self) -> npt.NDArray: return data -class TimeAxis(Parameter): +class TimeAxis(Parameter[npt.NDArray, "Keithley2600Channel"]): """ A simple :class:`.Parameter` that holds all the times (relative to the measurement start) at which the points of the time trace were acquired. @@ -238,7 +243,7 @@ class Keithley2600MeasurementStatus(StrEnum): } -class _ParameterWithStatus(Parameter): +class _ParameterWithStatus(Parameter[ParamRawDataType, "Keithley2600Channel"]): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -278,9 +283,6 @@ def snapshot_base( class _MeasurementCurrentParameter(_ParameterWithStatus): def set_raw(self, value: ParamRawDataType) -> None: - assert isinstance(self.instrument, Keithley2600Channel) - assert isinstance(self.root_instrument, Keithley2600) - smu_chan = self.instrument channel = smu_chan.channel @@ -289,9 +291,6 @@ def set_raw(self, value: ParamRawDataType) -> None: smu_chan._reset_measurement_statuses_of_parameters() def get_raw(self) -> ParamRawDataType: - assert isinstance(self.instrument, Keithley2600Channel) - assert isinstance(self.root_instrument, Keithley2600) - smu = self.instrument channel = self.instrument.channel @@ -307,9 +306,6 @@ def get_raw(self) -> ParamRawDataType: class _MeasurementVoltageParameter(_ParameterWithStatus): def set_raw(self, value: ParamRawDataType) -> None: - assert isinstance(self.instrument, Keithley2600Channel) - assert isinstance(self.root_instrument, Keithley2600) - smu_chan = self.instrument channel = smu_chan.channel @@ -318,9 +314,6 @@ def set_raw(self, value: ParamRawDataType) -> None: smu_chan._reset_measurement_statuses_of_parameters() def get_raw(self) -> ParamRawDataType: - assert isinstance(self.instrument, Keithley2600Channel) - assert isinstance(self.root_instrument, Keithley2600) - smu = self.instrument channel = self.instrument.channel @@ -334,13 +327,13 @@ def get_raw(self) -> ParamRawDataType: return value -class Keithley2600Channel(InstrumentChannel): +class Keithley2600Channel(InstrumentChannel["Keithley2600"]): """ Class to hold the two Keithley channels, i.e. SMUA and SMUB. """ - def __init__(self, parent: Instrument, name: str, channel: str) -> None: + def __init__(self, parent: Keithley2600, name: str, channel: str) -> None: """ Args: parent: The Instrument instance to which the channel is @@ -355,12 +348,12 @@ def __init__(self, parent: Instrument, name: str, channel: str) -> None: raise ValueError('channel must be either "smub" or "smua"') super().__init__(parent, name) - self.model = self._parent.model + self.model = self.parent.model self._extra_visa_timeout = 5000 self._measurement_duration_factor = 2 # Ensures that we are always above # the expected time. - vranges = self._parent._vranges - iranges = self._parent._iranges + vranges = self.parent._vranges + iranges = self.parent._iranges vlimit_minmax = self.parent._vlimit_minmax ilimit_minmax = self.parent._ilimit_minmax @@ -630,9 +623,8 @@ def __init__(self, parent: Instrument, name: str, channel: str) -> None: self.channel = channel def _reset_measurement_statuses_of_parameters(self) -> None: - assert isinstance(self.volt, _ParameterWithStatus) self.volt._measurement_status = None - assert isinstance(self.curr, _ParameterWithStatus) + self.curr._measurement_status = None def reset(self) -> None: @@ -645,7 +637,13 @@ def reset(self) -> None: log.debug(f"Reset channel {self.channel}. Updating settings...") self.snapshot(update=True) - def doFastSweep(self, start: float, stop: float, steps: int, mode: str) -> DataSet: + def doFastSweep( + self, + start: float, + stop: float, + steps: int, + mode: Literal["IV", "VI", "VIfourprobe"], + ) -> DataSet: """ Perform a fast sweep using a deployed lua script and return a QCoDeS DataSet with the sweep. @@ -705,23 +703,24 @@ def _fast_sweep( dV = (stop - start) / (steps - 1) - if mode == "IV": - meas = "i" - sour = "v" - func = "1" - sense_mode = "0" - elif mode == "VI": - meas = "v" - sour = "i" - func = "0" - sense_mode = "0" - elif mode == "VIfourprobe": - meas = "v" - sour = "i" - func = "0" - sense_mode = "1" - else: - raise ValueError(f"Invalid mode {mode}") + match mode: + case "IV": + meas = "i" + sour = "v" + func = "1" + sense_mode = "0" + case "VI": + meas = "v" + sour = "i" + func = "0" + sense_mode = "0" + case "VIfourprobe": + meas = "v" + sour = "i" + func = "0" + sense_mode = "1" + case _: + raise ValueError(f"Invalid mode {mode}") script = [ f"{channel}.measure.nplc = {nplc:.12f}", @@ -765,7 +764,7 @@ def _execute_lua(self, _script: list[str], steps: int) -> npt.NDArray: estimated_measurement_duration + _time_trace_extra_visa_timeout ) - self.write(self.root_instrument._scriptwrapper(program=_script, debug=True)) + self.write(self.parent._scriptwrapper(program=_script, debug=True)) # now poll all the data # The problem is that a '\n' character might by chance be present in @@ -774,9 +773,9 @@ def _execute_lua(self, _script: list[str], steps: int) -> npt.NDArray: received = 0 data = b"" # we must wait for the script to execute - with self.root_instrument.timeout.set_to(new_visa_timeout): + with self.parent.timeout.set_to(new_visa_timeout): while received < fullsize: - data_temp = self.root_instrument.visa_handle.read_raw() + data_temp = self.parent.visa_handle.read_raw() received += len(data_temp) data += data_temp diff --git a/src/qcodes/instrument_drivers/Keysight/Infiniium.py b/src/qcodes/instrument_drivers/Keysight/Infiniium.py index 4f559ea09747..1c4310701ffc 100644 --- a/src/qcodes/instrument_drivers/Keysight/Infiniium.py +++ b/src/qcodes/instrument_drivers/Keysight/Infiniium.py @@ -13,6 +13,7 @@ import qcodes.validators as vals from qcodes.instrument import ( ChannelList, + ChannelTuple, InstrumentBase, InstrumentBaseKWArgs, InstrumentChannel, @@ -89,7 +90,11 @@ def get_raw(self) -> npt.NDArray: ) -class DSOTraceParam(ParameterWithSetpoints): +class DSOTraceParam( + ParameterWithSetpoints[ + npt.NDArray, "KeysightInfiniiumChannel | KeysightInfiniiumFunction" + ] +): """ Trace parameter for the Infiniium series DSO """ @@ -180,7 +185,7 @@ def update_setpoints(self, preamble: "Sequence[str] | None" = None) -> None: acquisition if instr.cache_setpoints is False """ instrument: KeysightInfiniiumChannel | KeysightInfiniiumFunction - instrument = self.instrument # type: ignore[assignment] + instrument = self.instrument if preamble is None: instrument.write(f":WAV:SOUR {self._channel}") preamble = instrument.ask(":WAV:PRE?").strip().split(",") @@ -257,6 +262,7 @@ def __init__( self, parent: InstrumentBase, name: str, + channel: str, **kwargs: "Unpack[InstrumentBaseKWArgs]", ) -> None: """ @@ -264,6 +270,7 @@ def __init__( directly, rather initialize BoundMeasurementSubsystem or UnboundMeasurementSubsystem. """ + self._channel = channel super().__init__(parent, name, **kwargs) ################################### @@ -477,11 +484,8 @@ def __init__( """ Initialize measurement subsystem bound to a specific channel """ - # Bind the channel - self._channel = parent.channel_name - # Initialize measurement parameters - super().__init__(parent, name, **kwargs) + super().__init__(parent, name, channel=parent.channel_name, **kwargs) BoundMeasurement = KeysightInfiniiumBoundMeasurement @@ -500,11 +504,8 @@ def __init__( """ Initialize measurement subsystem where target is set by the parameter `source`. """ - # Blank channel - self._channel = "" - # Initialize measurement parameters - super().__init__(parent, name, **kwargs) + super().__init__(parent, name, channel="", **kwargs) self.source = Parameter( name="source", @@ -515,6 +516,12 @@ def __init__( snapshot_value=False, ) + @property + def root_instrument(self) -> "KeysightInfiniium": + root_instrument = super().root_instrument + assert isinstance(root_instrument, KeysightInfiniium) + return root_instrument + def _validate_source(self, source: str) -> str: """Validate and set the source.""" valid_channels = f"CHAN[1-{self.root_instrument.no_channels}]" @@ -692,7 +699,7 @@ def _get_func(self) -> str: """ -class KeysightInfiniiumChannel(InstrumentChannel): +class KeysightInfiniiumChannel(InstrumentChannel["KeysightInfiniium"]): def __init__( self, parent: "KeysightInfiniium", @@ -1081,7 +1088,10 @@ def __init__( channel = KeysightInfiniiumChannel(self, f"chan{i}", i) _channels.append(channel) self.add_submodule(f"ch{i}", channel) - self.add_submodule("channels", _channels.to_channel_tuple()) + self.channels: ChannelTuple[KeysightInfiniiumChannel] = self.add_submodule( + "channels", _channels.to_channel_tuple() + ) + """Tuple of oscilloscope channels.""" # Functions _functions = ChannelList( @@ -1093,11 +1103,17 @@ def __init__( self.add_submodule(f"func{i}", function) # Have to call channel list "funcs" here as functions is a # reserved name in Instrument. - self.add_submodule("funcs", _functions.to_channel_tuple()) + self.funcs: ChannelTuple[KeysightInfiniiumFunction] = self.add_submodule( + "funcs", _functions.to_channel_tuple() + ) + """Tuple of oscilloscope functions.""" # Submodules meassubsys = KeysightInfiniiumUnboundMeasurement(self, "measure") - self.add_submodule("measure", meassubsys) + self.measure: KeysightInfiniiumUnboundMeasurement = self.add_submodule( + "measure", meassubsys + ) + """Unbound measurement subsystem.""" def _query_capabilities(self) -> None: """ diff --git a/src/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py b/src/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py index ae980f494cb3..1a6a7cf5b704 100644 --- a/src/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py +++ b/src/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py @@ -4,7 +4,6 @@ from qcodes import validators as vals from qcodes.instrument import ( - Instrument, InstrumentBaseKWArgs, InstrumentChannel, VisaInstrument, @@ -26,14 +25,14 @@ # 33200, 33500, and 33600 -class Keysight33xxxOutputChannel(InstrumentChannel): +class Keysight33xxxOutputChannel(InstrumentChannel["Keysight33xxx"]): """ Class to hold the output channel of a Keysight 33xxxx waveform generator. """ def __init__( self, - parent: Instrument, + parent: "Keysight33xxx", name: str, channum: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -317,7 +316,7 @@ def val_parser(parser: type, inputstring: str) -> float | int: OutputChannel = Keysight33xxxOutputChannel -class Keysight33xxxSyncChannel(InstrumentChannel): +class Keysight33xxxSyncChannel(InstrumentChannel["Keysight33xxx"]): """ Class to hold the sync output of a Keysight 33xxxx waveform generator. Has very few parameters for single channel instruments. @@ -325,7 +324,7 @@ class Keysight33xxxSyncChannel(InstrumentChannel): def __init__( self, - parent: Instrument, + parent: "Keysight33xxx", name: str, **kwargs: "Unpack[InstrumentBaseKWArgs]", ): diff --git a/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py b/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py index 51244963d6b3..eed00159df49 100644 --- a/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py +++ b/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py @@ -98,7 +98,6 @@ def __init__( **kwargs: Unpack[InstrumentBaseKWArgs], ): super().__init__(parent, name, *arg, **kwargs) - self.root_instrument: KeysightN9030B self._additional_wait = additional_wait self._min_freq = -8e7 @@ -111,7 +110,7 @@ def __init__( } opt: str | None = None for hw_opt_for_max_freq in self._valid_max_freq: - if hw_opt_for_max_freq in self.root_instrument.options(): + if hw_opt_for_max_freq in self.parent.options(): opt = hw_opt_for_max_freq assert opt is not None self._max_freq = self._valid_max_freq[opt] @@ -454,23 +453,22 @@ def _get_data(self, trace_num: int) -> npt.NDArray[np.float64]: """ Gets data from the measurement. """ - root_instr = self.root_instrument # Check if we should run a new sweep - auto_sweep = root_instr.auto_sweep() + auto_sweep = self.parent.auto_sweep() if auto_sweep: # If we need to run a sweep, we need to set the timeout to take into account # the sweep time timeout = self.sweep_time() + self._additional_wait - with root_instr.timeout.set_to(timeout): - data = root_instr.visa_handle.query_binary_values( - f":READ:{root_instr.measurement()}{trace_num}?", + with self.parent.timeout.set_to(timeout): + data = self.parent.visa_handle.query_binary_values( + f":READ:{self.parent.measurement()}{trace_num}?", datatype="d", is_big_endian=False, ) else: - data = root_instr.visa_handle.query_binary_values( - f":FETC:{root_instr.measurement()}{trace_num}?", + data = self.parent.visa_handle.query_binary_values( + f":FETC:{self.parent.measurement()}{trace_num}?", datatype="d", is_big_endian=False, ) @@ -491,9 +489,9 @@ def setup_swept_sa_sweep(self, start: float, stop: float, npts: int) -> None: """ Sets up the Swept SA measurement sweep for Spectrum Analyzer Mode. """ - self.root_instrument.mode("SA") - if "SAN" in self.root_instrument.available_meas(): - self.root_instrument.measurement("SAN") + self.parent.mode("SA") + if "SAN" in self.parent.available_meas(): + self.parent.measurement("SAN") else: raise RuntimeError( "Swept SA measurement is not available on your " @@ -537,7 +535,7 @@ def __init__( } opt: str | None = None for hw_opt_for_max_freq in self._valid_max_freq: - if hw_opt_for_max_freq in self.root_instrument.options(): + if hw_opt_for_max_freq in self.parent.options(): opt = hw_opt_for_max_freq assert opt is not None self._max_freq = self._valid_max_freq[opt] @@ -668,9 +666,8 @@ def _get_data(self, trace_num: int) -> ParamRawDataType: """ Gets data from the measurement. """ - root_instr = self.root_instrument - measurement = root_instr.measurement() - raw_data = root_instr.visa_handle.query_binary_values( + measurement = self.parent.measurement() + raw_data = self.parent.visa_handle.query_binary_values( f":READ:{measurement}1?", datatype="d", is_big_endian=False, @@ -684,7 +681,7 @@ def _get_data(self, trace_num: int) -> ParamRawDataType: return -1 * np.ones(self.npts()) try: - data = root_instr.visa_handle.query_binary_values( + data = self.parent.visa_handle.query_binary_values( f":READ:{measurement}{trace_num}?", datatype="d", is_big_endian=False, @@ -701,9 +698,9 @@ def setup_log_plot_sweep( """ Sets up the Log Plot measurement sweep for Phase Noise Mode. """ - self.root_instrument.mode("PNOISE") - if "LPL" in self.root_instrument.available_meas(): - self.root_instrument.measurement("LPL") + self.parent.mode("PNOISE") + if "LPL" in self.parent.available_meas(): + self.parent.measurement("LPL") else: raise RuntimeError( "Log Plot measurement is not available on your " diff --git a/src/qcodes/instrument_drivers/Keysight/KtMAwg.py b/src/qcodes/instrument_drivers/Keysight/KtMAwg.py index 309a0aebacf9..f1c36735ae98 100644 --- a/src/qcodes/instrument_drivers/Keysight/KtMAwg.py +++ b/src/qcodes/instrument_drivers/Keysight/KtMAwg.py @@ -12,7 +12,7 @@ from typing_extensions import Unpack -class KeysightM9336AAWGChannel(InstrumentChannel): +class KeysightM9336AAWGChannel(InstrumentChannel["KeysightM9336A"]): """ Represent the three channels of the Keysight KTM Awg driver. The channels can be independently controlled and programmed with @@ -135,6 +135,12 @@ def __init__( ) """Parameter digital_gain""" + @property + def root_instrument(self) -> "KeysightM9336A": + root_instrument = super().root_instrument + assert isinstance(root_instrument, KeysightM9336A) + return root_instrument + def load_waveform(self, filename: str) -> None: path = ctypes.create_string_buffer(filename.encode("ascii")) self._awg_handle = ctypes.c_int32(-1) diff --git a/src/qcodes/instrument_drivers/Keysight/N52xx.py b/src/qcodes/instrument_drivers/Keysight/N52xx.py index 9ebc10609a82..cdca52023407 100644 --- a/src/qcodes/instrument_drivers/Keysight/N52xx.py +++ b/src/qcodes/instrument_drivers/Keysight/N52xx.py @@ -8,6 +8,7 @@ from qcodes.instrument import ( ChannelList, + ChannelTuple, InstrumentBaseKWArgs, InstrumentChannel, VisaInstrument, @@ -69,7 +70,7 @@ def get_raw(self) -> npt.NDArray: return np.linspace(0, self._stopparam(), self._pointsparam()) -class FormattedSweep(ParameterWithSetpoints): +class FormattedSweep(ParameterWithSetpoints[npt.NDArray, "KeysightPNATrace"]): """ Mag will run a sweep, including averaging, before returning data. As such, wait time in a loop is not needed. @@ -96,7 +97,7 @@ def setpoints(self) -> "Sequence[ParameterBase]": """ if self.instrument is None: raise RuntimeError("Cannot return setpoints if not attached to instrument") - root_instrument: KeysightPNABase = self.root_instrument # type: ignore[assignment] + root_instrument: KeysightPNABase = self.root_instrument sweep_type = root_instrument.sweep_type() if sweep_type == "LIN": return (root_instrument.frequency_axis,) @@ -115,10 +116,16 @@ def setpoints(self, setpoints: Any) -> None: """ return + @property + def root_instrument(self) -> "KeysightPNABase": + root_instrument = super().root_instrument + assert isinstance(root_instrument, KeysightPNABase) + return root_instrument + def get_raw(self) -> npt.NDArray: if self.instrument is None: raise RuntimeError("Cannot get data without instrument") - root_instr = self.instrument.root_instrument + root_instr = self.root_instrument # Check if we should run a new sweep auto_sweep = root_instr.auto_sweep() @@ -182,7 +189,7 @@ def _set_power_limits(self, min_power: float, max_power: float) -> None: "Alis for backwards compatibility" -class KeysightPNATrace(InstrumentChannel): +class KeysightPNATrace(InstrumentChannel["KeysightPNABase"]): """ Allow operations on individual PNA traces. """ @@ -292,6 +299,12 @@ def __init__( ) """Parameter polar""" + @property + def root_instrument(self) -> "KeysightPNABase": + root_instrument = super().root_instrument + assert isinstance(root_instrument, KeysightPNABase) + return root_instrument + def disable(self) -> None: """ Disable this trace on the PNA @@ -438,7 +451,11 @@ def __init__( ) ports.append(port) self.add_submodule(f"port{port_num}", port) - self.add_submodule("ports", ports.to_channel_tuple()) + + self.ports: ChannelTuple[KeysightPNAPort] = self.add_submodule( + "ports", ports.to_channel_tuple() + ) + """Tuple of KeysightPNAPort submodules""" # RF output self.output: Parameter = self.add_parameter( diff --git a/src/qcodes/instrument_drivers/Keysight/keysight_34934a.py b/src/qcodes/instrument_drivers/Keysight/keysight_34934a.py index a1dcdfbf1ae7..2ac7a6e7dec9 100644 --- a/src/qcodes/instrument_drivers/Keysight/keysight_34934a.py +++ b/src/qcodes/instrument_drivers/Keysight/keysight_34934a.py @@ -13,11 +13,11 @@ from qcodes.instrument import ( InstrumentBaseKWArgs, - InstrumentChannel, - VisaInstrument, ) from qcodes.parameters import Parameter + from .keysight_34980a import Keysight34980A + class Keysight34934A(Keysight34980ASwitchMatrixSubModule): """ @@ -32,7 +32,7 @@ class Keysight34934A(Keysight34980ASwitchMatrixSubModule): def __init__( self, - parent: "VisaInstrument | InstrumentChannel", + parent: "Keysight34980A", name: str, slot: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", diff --git a/src/qcodes/instrument_drivers/Keysight/keysight_34980a_submodules.py b/src/qcodes/instrument_drivers/Keysight/keysight_34980a_submodules.py index 27cc8f879969..094154a53193 100644 --- a/src/qcodes/instrument_drivers/Keysight/keysight_34980a_submodules.py +++ b/src/qcodes/instrument_drivers/Keysight/keysight_34980a_submodules.py @@ -1,15 +1,17 @@ from typing import TYPE_CHECKING -from qcodes.instrument import InstrumentBaseKWArgs, InstrumentChannel, VisaInstrument +from qcodes.instrument import InstrumentBaseKWArgs, InstrumentChannel if TYPE_CHECKING: from typing_extensions import Unpack + from .keysight_34980a import Keysight34980A -class Keysight34980ASwitchMatrixSubModule(InstrumentChannel): + +class Keysight34980ASwitchMatrixSubModule(InstrumentChannel["Keysight34980A"]): def __init__( self, - parent: VisaInstrument | InstrumentChannel, + parent: "Keysight34980A", name: str, slot: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", diff --git a/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_base.py b/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_base.py index e24ac2addf19..d25f4a970cb0 100644 --- a/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_base.py +++ b/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_base.py @@ -340,7 +340,7 @@ def self_calibration( """ msg = MessageBuilder().cal_query(slot=slot) - with self.root_instrument.timeout.set_to(self.calibration_time_out): + with self.timeout.set_to(self.calibration_time_out): response = self.ask(msg.message) return constants.CALResponse(int(response)) diff --git a/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_module.py b/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_module.py index 458bc0c75f7d..9a219dc378c6 100644 --- a/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_module.py +++ b/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1500_module.py @@ -373,7 +373,7 @@ def clear_timer_count(self) -> None: is not effective for the 4 byte binary data output format (FMT3 and FMT4). """ - self.root_instrument.clear_timer_count(chnum=self.channels) + self.parent.clear_timer_count(chnum=self.channels) class StatusMixin: diff --git a/src/qcodes/instrument_drivers/Keysight/private/Keysight_344xxA_submodules.py b/src/qcodes/instrument_drivers/Keysight/private/Keysight_344xxA_submodules.py index 2fcf0bca27ae..4166adca35f8 100644 --- a/src/qcodes/instrument_drivers/Keysight/private/Keysight_344xxA_submodules.py +++ b/src/qcodes/instrument_drivers/Keysight/private/Keysight_344xxA_submodules.py @@ -553,9 +553,8 @@ def _acquire_time_trace(self) -> npt.NDArray[np.float64]: npts = self.instrument.timetrace_npts() meas_time = npts * dt disp_text = f"Acquiring {npts} samples" # display limit: 40 characters - new_timeout = max( - self._acquire_timeout_fudge_factor * meas_time, self.instrument.timeout() - ) + old_timeout = self.instrument.timeout() or float("inf") + new_timeout = max(self._acquire_timeout_fudge_factor * meas_time, old_timeout) with ExitStack() as stack: stack.enter_context(self.instrument.trigger.count.set_to(1)) @@ -583,7 +582,7 @@ def get_raw(self) -> npt.NDArray[np.float64]: return data -class TimeAxis(Parameter): +class TimeAxis(Parameter[npt.NDArray, "Keysight344xxA"]): """ A simple :class:`.Parameter` that holds all the times (relative to the measurement start) at which the points of the time trace were acquired. diff --git a/src/qcodes/instrument_drivers/QDev/QDac_channels.py b/src/qcodes/instrument_drivers/QDev/QDac_channels.py index a908b32d947f..32c976ad1496 100644 --- a/src/qcodes/instrument_drivers/QDev/QDac_channels.py +++ b/src/qcodes/instrument_drivers/QDev/QDac_channels.py @@ -11,7 +11,7 @@ from qcodes import validators as vals from qcodes.instrument import ( ChannelList, - Instrument, + ChannelTuple, InstrumentBaseKWArgs, InstrumentChannel, VisaInstrument, @@ -29,7 +29,7 @@ log = logging.getLogger(__name__) -class QDevQDacChannel(InstrumentChannel): +class QDevQDacChannel(InstrumentChannel["QDevQDac"]): """ A single output channel of the QDac. @@ -40,7 +40,7 @@ class QDevQDacChannel(InstrumentChannel): def __init__( self, - parent: Instrument, + parent: "QDevQDac", name: str, channum: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -143,9 +143,20 @@ def snapshot_base( update: bool | None = False, params_to_skip_update: "Sequence[str] | None" = None, ) -> dict[Any, Any]: - update_currents = self._parent._update_currents and update - if update and not self._parent._get_status_performed: - self._parent._update_cache(readcurrents=update_currents) + # setting update not None will override parent setting + # otherwise we use parent setting + # parent._update | update | do update + # True | True | True + # True | None | True + # True | False | False + # False | True | True + # False | None | False + # False | False | False + update_currents = ( + self.parent._update_currents and update is not False + ) or update is True + if update and not self.parent._get_status_performed: + self.parent._update_cache(readcurrents=update_currents) # call get_status rather than getting the status individually for # each parameter. This is only done if _get_status_performed is False # this is used to signal that the parent has already called it and @@ -287,7 +298,10 @@ def __init__( channels.append(channel) # Should raise valueerror if name is invalid (silently fails now) self.add_submodule(f"ch{i:02}", channel) - self.add_submodule("channels", channels.to_channel_tuple()) + self.channels: ChannelTuple[QDevQDacChannel] = self.add_submodule( + "channels", channels.to_channel_tuple() + ) + """ChannelTuple containing all QDevQDacChannel instances""" for board in range(6): for sensor in range(3): diff --git a/src/qcodes/instrument_drivers/american_magnetics/AMI430_visa.py b/src/qcodes/instrument_drivers/american_magnetics/AMI430_visa.py index de86d4996791..6ae5a78dc159 100644 --- a/src/qcodes/instrument_drivers/american_magnetics/AMI430_visa.py +++ b/src/qcodes/instrument_drivers/american_magnetics/AMI430_visa.py @@ -47,7 +47,7 @@ class AMI430Warning(UserWarning): pass -class AMI430SwitchHeater(InstrumentChannel): +class AMI430SwitchHeater(InstrumentChannel["AMIModel430"]): class _Decorators: @classmethod def check_enabled( @@ -140,14 +140,14 @@ def _check_enabled(self) -> bool: @_Decorators.check_enabled def _on(self) -> None: self.write("PS 1") - while self._parent.ramping_state() == "heating switch": - self._parent._sleep(0.5) + while self.parent.ramping_state() == "heating switch": + self.parent._sleep(0.5) @_Decorators.check_enabled def _off(self) -> None: self.write("PS 0") - while self._parent.ramping_state() == "cooling switch": - self._parent._sleep(0.5) + while self.parent.ramping_state() == "cooling switch": + self.parent._sleep(0.5) def _check_state(self) -> bool: if self.enabled() is False: @@ -231,8 +231,6 @@ def __init__( self._parent_instrument = None - # Add reset function - self.add_function("reset", call_cmd="*RST") if reset: self.reset() @@ -262,8 +260,8 @@ def __init__( """Parameter current_ramp_limit""" self.field_ramp_limit: Parameter = self.add_parameter( "field_ramp_limit", - get_cmd=self.current_ramp_limit, - set_cmd=self.current_ramp_limit, + get_cmd=self.current_ramp_limit.get, + set_cmd=self.current_ramp_limit.set, scale=1 / float(self.ask("COIL?")), unit="T/s", ) @@ -320,8 +318,7 @@ def __init__( "is_quenched", get_cmd="QU?", val_mapping={True: 1, False: 0} ) """Parameter is_quenched""" - self.add_function("reset_quench", call_cmd="QU 0") - self.add_function("set_quenched", call_cmd="QU 1") + self.ramping_state: Parameter = self.add_parameter( "ramping_state", get_cmd="STATE?", @@ -356,17 +353,39 @@ def __init__( ) """Submodule the switch heater submodule.""" - # Add interaction functions - self.add_function("get_error", call_cmd="SYST:ERR?") - self.add_function("ramp", call_cmd="RAMP") - self.add_function("pause", call_cmd="PAUSE") - self.add_function("zero", call_cmd="ZERO") - # Correctly assign all units self._update_units() self.connect_message() + def get_error(self) -> str: + """Get the last error from the instrument""" + return self.ask("SYST:ERR?") + + def ramp(self) -> None: + """Start ramping to the setpoint""" + self.write("RAMP") + + def pause(self) -> None: + """Pause ramping""" + self.write("PAUSE") + + def zero(self) -> None: + """Ramp to zero current""" + self.write("ZERO") + + def reset_quench(self) -> None: + """Reset a quench condition on the instrument""" + self.write("QU 0") + + def set_quenched(self) -> None: + """Set a quench condition on the instrument""" + self.write("QU 1") + + def reset(self) -> None: + """Reset the instrument to default settings""" + self.write("*RST") + def _sleep(self, t: float) -> None: """ Sleep for a number of seconds t. If we are or using @@ -1045,8 +1064,12 @@ def _adjust_child_instruments(self, values: tuple[float, float, float]) -> None: raise ValueError("_set_fields aborted; field would exceed limit") # Check if the individual instruments are ready - for name in ("x", "y", "z"): - instrument = getattr(self, f"_instrument_{name}") + Instruments_to_check = ( + self._instrument_x, + self._instrument_y, + self._instrument_z, + ) + for instrument in Instruments_to_check: if instrument.ramping_state() == "ramping": msg = f"_set_fields aborted; magnet {instrument} is already ramping" raise AMI430Exception(msg) diff --git a/src/qcodes/instrument_drivers/oxford/MercuryiPS_VISA.py b/src/qcodes/instrument_drivers/oxford/MercuryiPS_VISA.py index 9038c474f4dc..3922814c2c58 100644 --- a/src/qcodes/instrument_drivers/oxford/MercuryiPS_VISA.py +++ b/src/qcodes/instrument_drivers/oxford/MercuryiPS_VISA.py @@ -76,14 +76,14 @@ def _temp_parser(response: str) -> float: return float(response.rsplit(":", maxsplit=1)[-1][:-1]) -class OxfordMercuryWorkerPS(InstrumentChannel): +class OxfordMercuryWorkerPS(InstrumentChannel["OxfordMercuryiPS"]): """ Class to hold a worker power supply for the Oxford MercuryiPS """ def __init__( self, - parent: VisaInstrument, + parent: OxfordMercuryiPS, name: str, UID: str, **kwargs: Unpack[InstrumentBaseKWArgs], @@ -108,7 +108,7 @@ def __init__( # The firmware update from 2.5 -> 2.6 changed the command # syntax slightly - if version.parse(self.root_instrument.firmware) >= version.parse("2.6"): + if version.parse(self.parent.firmware) >= version.parse("2.6"): self.psu_string = "SPSU" else: self.psu_string = "PSU" @@ -348,10 +348,14 @@ def __init__( self.firmware = self.IDN()["firmware"] # TODO: Query instrument to ensure which PSUs are actually present - for grp in ["GRPX", "GRPY", "GRPZ"]: - psu_name = grp - psu = OxfordMercuryWorkerPS(self, psu_name, grp) - self.add_submodule(psu_name, psu) + GRPX = OxfordMercuryWorkerPS(self, "GRPX", "GRPX") + self.GRPX: OxfordMercuryWorkerPS = self.add_submodule("GRPX", GRPX) + + GRPY = OxfordMercuryWorkerPS(self, "GRPY", "GRPY") + self.GRPY: OxfordMercuryWorkerPS = self.add_submodule("GRPY", GRPY) + + GRPZ = OxfordMercuryWorkerPS(self, "GRPZ", "GRPZ") + self.GRPZ: OxfordMercuryWorkerPS = self.add_submodule("GRPZ", GRPZ) self._field_limits = field_limits if field_limits else lambda x, y, z: True diff --git a/src/qcodes/instrument_drivers/rigol/Rigol_DG4000.py b/src/qcodes/instrument_drivers/rigol/Rigol_DG4000.py index 8d0c2fafe546..873cfc7289d8 100644 --- a/src/qcodes/instrument_drivers/rigol/Rigol_DG4000.py +++ b/src/qcodes/instrument_drivers/rigol/Rigol_DG4000.py @@ -738,13 +738,15 @@ def __init__( # Trace self.add_function("upload_data", call_cmd=self._upload_data, args=[Anything()]) - self.add_function("reset", call_cmd="*RST") - if reset: self.reset() self.connect_message() + def reset(self) -> None: + """Reset the instrument to default settings.""" + self.write("*RST") + def _upload_data(self, data: "Sequence[float] | npt.NDArray") -> None: """ Upload data to the AWG memory. diff --git a/src/qcodes/instrument_drivers/rigol/Rigol_DS1074Z.py b/src/qcodes/instrument_drivers/rigol/Rigol_DS1074Z.py index e8a281589037..b3e196205544 100644 --- a/src/qcodes/instrument_drivers/rigol/Rigol_DS1074Z.py +++ b/src/qcodes/instrument_drivers/rigol/Rigol_DS1074Z.py @@ -19,7 +19,7 @@ from qcodes.parameters import Parameter -class RigolDS1074ZChannel(InstrumentChannel): +class RigolDS1074ZChannel(InstrumentChannel["RigolDS1074Z"]): """ Contains methods and attributes specific to the Rigol oscilloscope channels. @@ -58,9 +58,9 @@ def __init__( """Parameter trace""" def _get_full_trace(self) -> npt.NDArray: - y_ori = self.root_instrument.waveform_yorigin() - y_increm = self.root_instrument.waveform_yincrem() - y_ref = self.root_instrument.waveform_yref() + y_ori = self.parent.waveform_yorigin() + y_increm = self.parent.waveform_yincrem() + y_ref = self.parent.waveform_yref() y_raw = self._get_raw_trace() y_raw_shifted = y_raw - y_ori - y_ref full_data = np.multiply(y_raw_shifted, y_increm) @@ -68,13 +68,13 @@ def _get_full_trace(self) -> npt.NDArray: def _get_raw_trace(self) -> npt.NDArray: # set the out type from oscilloscope channels to WORD - self.root_instrument.write(":WAVeform:FORMat WORD") + self.parent.write(":WAVeform:FORMat WORD") # set the channel from where data will be obtained - self.root_instrument.data_source(f"ch{self.channel}") + self.parent.data_source(f"ch{self.channel}") # Obtain the trace - raw_trace_val = self.root_instrument.visa_handle.query_binary_values( + raw_trace_val = self.parent.visa_handle.query_binary_values( "WAV:DATA?", datatype="h", is_big_endian=False, expect_termination=False ) return np.array(raw_trace_val) @@ -231,10 +231,8 @@ def _get_time_axis(self) -> npt.NDArray: return xdata def _get_trigger_level(self) -> str: - trigger_level = self.root_instrument.ask( - f":TRIGger:{self.trigger_mode()}:LEVel?" - ) + trigger_level = self.ask(f":TRIGger:{self.trigger_mode()}:LEVel?") return trigger_level def _set_trigger_level(self, value: str) -> None: - self.root_instrument.write(f":TRIGger:{self.trigger_mode()}:LEVel {value}") + self.write(f":TRIGger:{self.trigger_mode()}:LEVel {value}") diff --git a/src/qcodes/instrument_drivers/rohde_schwarz/RTO1000.py b/src/qcodes/instrument_drivers/rohde_schwarz/RTO1000.py index 0be47302064d..94ad53118d20 100644 --- a/src/qcodes/instrument_drivers/rohde_schwarz/RTO1000.py +++ b/src/qcodes/instrument_drivers/rohde_schwarz/RTO1000.py @@ -1,6 +1,5 @@ # All manual references are to R&S RTO Digital Oscilloscope User Manual # for firmware 3.65, 2017 - import logging import time import warnings @@ -12,6 +11,8 @@ import qcodes.validators as vals from qcodes.instrument import ( + ChannelList, + ChannelTuple, Instrument, InstrumentBaseKWArgs, InstrumentChannel, @@ -26,9 +27,13 @@ log = logging.getLogger(__name__) -class ScopeTrace(ArrayParameter): +class ScopeTrace(ArrayParameter[npt.NDArray, "RohdeSchwarzRTO1000ScopeChannel"]): def __init__( - self, name: str, instrument: InstrumentChannel, channum: int, **kwargs: Any + self, + name: str, + instrument: "RohdeSchwarzRTO1000ScopeChannel", + channum: int, + **kwargs: Any, ) -> None: """ The ScopeTrace parameter is attached to a channel of the oscilloscope. @@ -53,6 +58,12 @@ def __init__( self.channum = channum self._trace_ready = False + @property + def root_instrument(self) -> "RohdeSchwarzRTO1000": + root_instrument = super().root_instrument + assert isinstance(root_instrument, RohdeSchwarzRTO1000) + return root_instrument + def prepare_trace(self) -> None: """ Prepare the scope for returning data, calculate the setpoints @@ -452,7 +463,7 @@ def __init__( ScopeMeasurement = RohdeSchwarzRTO1000ScopeMeasurement -class RohdeSchwarzRTO1000ScopeChannel(InstrumentChannel): +class RohdeSchwarzRTO1000ScopeChannel(InstrumentChannel["RohdeSchwarzRTO1000"]): """ Class to hold an input channel of the scope. @@ -462,7 +473,7 @@ class RohdeSchwarzRTO1000ScopeChannel(InstrumentChannel): def __init__( self, - parent: Instrument, + parent: "RohdeSchwarzRTO1000", name: str, channum: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -631,12 +642,12 @@ def __init__( def _set_range(self, value: float) -> None: self.scale.cache.set(value / 10) - self._parent.write(f"CHANnel{self.channum}:RANGe {value}") + self.parent.write(f"CHANnel{self.channum}:RANGe {value}") def _set_scale(self, value: float) -> None: self.range.cache.set(value * 10) - self._parent.write(f"CHANnel{self.channum}:SCALe {value}") + self.parent.write(f"CHANnel{self.channum}:SCALe {value}") ScopeChannel = RohdeSchwarzRTO1000ScopeChannel @@ -966,15 +977,31 @@ def __init__( """Parameter error_next""" # Add the channels to the instrument + scope_channels = ChannelList( + self, "scope_channels", RohdeSchwarzRTO1000ScopeChannel + ) + """ChannelTuple holding the scope channels. + """ for ch in range(1, self.num_chans + 1): chan = RohdeSchwarzRTO1000ScopeChannel(self, f"channel{ch}", ch) + scope_channels.append(chan) self.add_submodule(f"ch{ch}", chan) - + self.scope_channels: ChannelTuple[RohdeSchwarzRTO1000ScopeChannel] = ( + self.add_submodule("scope_channels", scope_channels.to_channel_tuple()) + ) + measurements = ChannelList( + self, "measurements", RohdeSchwarzRTO1000ScopeMeasurement + ) for measId in range(1, self.num_meas + 1): measCh = RohdeSchwarzRTO1000ScopeMeasurement( self, f"measurement{measId}", measId ) + measurements.append(measCh) self.add_submodule(f"meas{measId}", measCh) + self.measurements: ChannelTuple[RohdeSchwarzRTO1000ScopeMeasurement] = ( + self.add_submodule("measurements", measurements.to_channel_tuple()) + ) + """ChannelTuple holding the scope measurements.""" self.add_function("stop", call_cmd="STOP") self.add_function("reset", call_cmd="*RST") @@ -1055,10 +1082,8 @@ def _make_traces_not_ready(self) -> None: """ Make the scope traces be not ready. """ - self.ch1.trace._trace_ready = False - self.ch2.trace._trace_ready = False - self.ch3.trace._trace_ready = False - self.ch4.trace._trace_ready = False + for chan in self.scope_channels: + chan.trace._trace_ready = False def _set_trigger_level(self, value: float) -> None: """ @@ -1071,7 +1096,7 @@ def _set_trigger_level(self, value: float) -> None: source = trans[self.trigger_source.get()] if source != 5: submodule = self.submodules[f"ch{source}"] - assert isinstance(submodule, InstrumentChannel) + assert isinstance(submodule, RohdeSchwarzRTO1000ScopeChannel) v_range = submodule.range() offset = submodule.offset() diff --git a/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py b/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py index 1529846de6c1..f8fad7c1f034 100644 --- a/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py +++ b/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py @@ -397,7 +397,7 @@ def get_raw(self) -> npt.NDArray[np.floating]: return self.instrument._get_sweep_data() -class RohdeSchwarzZNBChannel(InstrumentChannel): +class RohdeSchwarzZNBChannel(InstrumentChannel["RohdeSchwarzZNBBase"]): def __init__( self, parent: "RohdeSchwarzZNBBase", @@ -432,7 +432,7 @@ def __init__( if existing_trace_to_bind_to is None: self._tracename = f"Trc{channel}" else: - traces = self._parent.ask("CONFigure:TRACe:CATalog?") + traces = self.parent.ask("CONFigure:TRACe:CATalog?") if existing_trace_to_bind_to not in traces: raise RuntimeError( f"Trying to bind to" @@ -759,12 +759,10 @@ def __init__( def set_electrical_delay_auto(self) -> None: n = self._instrument_channel - self.root_instrument.write(f"SENS{n}:CORR:EDEL:AUTO ONCE") + self.parent.write(f"SENS{n}:CORR:EDEL:AUTO ONCE") def autoscale(self) -> None: - self.root_instrument.write( - f"DISPlay:TRACe1:Y:SCALe:AUTO ONCE, {self._tracename}" - ) + self.parent.write(f"DISPlay:TRACe1:Y:SCALe:AUTO ONCE, {self._tracename}") def _get_format(self, tracename: str) -> str: n = self._instrument_channel @@ -925,7 +923,7 @@ def _get_sweep_data(self, force_polar: bool = False) -> npt.NDArray: # preserve original state of the znb with self.status.set_to(1): - self.root_instrument.cont_meas_off() + self.parent.cont_meas_off() try: # if force polar is set, the SDAT data format will be used. # Here the data will be transferred as a complex number @@ -935,7 +933,7 @@ def _get_sweep_data(self, force_polar: bool = False) -> npt.NDArray: else: data_format_command = "FDAT" - with self.root_instrument.timeout.set_to(self._get_timeout()): + with self.parent.timeout.set_to(self._get_timeout()): # instrument averages over its last 'avg' number of sweeps # need to ensure averaged result is returned for _ in range(self.avg()): @@ -950,7 +948,7 @@ def _get_sweep_data(self, force_polar: bool = False) -> npt.NDArray: if self.format() in ["Polar", "Complex", "Smith", "Inverse Smith"]: data = data[0::2] + 1j * data[1::2] finally: - self.root_instrument.cont_meas_on() + self.parent.cont_meas_on() return data def setup_cw_sweep(self) -> None: @@ -977,7 +975,7 @@ def setup_cw_sweep(self) -> None: self.auto_sweep_time_enabled(True) # Set cont measurement off here so we don't have to send that command # while measuring later. - self.root_instrument.cont_meas_off() + self.parent.cont_meas_off() def setup_lin_sweep(self) -> None: """ @@ -985,7 +983,7 @@ def setup_lin_sweep(self) -> None: """ self.sweep_type("Linear") self.averaging_enabled(True) - self.root_instrument.cont_meas_on() + self.parent.cont_meas_on() def _check_cw_sweep(self) -> None: """ @@ -998,7 +996,7 @@ def _check_cw_sweep(self) -> None: f"mode, instead it is: {self.sweep_type()}" ) - if not self.root_instrument.rf_power(): + if not self.parent.rf_power(): log.warning("RF output is off when getting sweep data") # It is possible that the instrument and QCoDeS disagree about @@ -1016,7 +1014,7 @@ def _check_cw_sweep(self) -> None: # Set the format to complex. self.format("Complex") # Set cont measurement off. - self.root_instrument.cont_meas_off() + self.parent.cont_meas_off() # Cache the sweep time so it is up to date when setting timeouts self.sweep_time() @@ -1027,7 +1025,7 @@ def _get_cw_data(self) -> tuple[npt.NDArray, npt.NDArray]: self._check_cw_sweep() with self.status.set_to(1): - with self.root_instrument.timeout.set_to(self._get_timeout()): + with self.parent.timeout.set_to(self._get_timeout()): self.write(f"INIT{self._instrument_channel}:IMM; *WAI") data_str = self.ask(f"CALC{self._instrument_channel}:DATA? SDAT") data = np.array(data_str.rstrip().split(",")).astype("float64") @@ -1037,7 +1035,7 @@ def _get_cw_data(self) -> tuple[npt.NDArray, npt.NDArray]: return i, q def _get_timeout(self) -> float: - timeout = self.root_instrument.timeout() or float("+inf") + timeout = self.parent.timeout() or float("+inf") timeout = max(self.sweep_time.cache.get() * 1.5, timeout) return timeout diff --git a/src/qcodes/instrument_drivers/stahl/stahl.py b/src/qcodes/instrument_drivers/stahl/stahl.py index be300f5e7e76..b4d9c0d0358c 100644 --- a/src/qcodes/instrument_drivers/stahl/stahl.py +++ b/src/qcodes/instrument_drivers/stahl/stahl.py @@ -57,12 +57,12 @@ def inner(*args: Any) -> Any: return inner -class StahlChannel(InstrumentChannel): +class StahlChannel(InstrumentChannel["Stahl"]): acknowledge_reply = chr(6) def __init__( self, - parent: VisaInstrument, + parent: "Stahl", name: str, channel_number: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -215,7 +215,10 @@ def __init__( self.add_submodule(name, channel) channels.append(channel) - self.add_submodule("channel", channels) + self.channels: ChannelList[StahlChannel] = self.add_submodule( + "channel", channels + ) + """List of channels""" self.temperature: Parameter = self.add_parameter( "temperature", diff --git a/src/qcodes/instrument_drivers/stanford_research/SR86x.py b/src/qcodes/instrument_drivers/stanford_research/SR86x.py index a7d7afa50cde..a206d0f5da9c 100644 --- a/src/qcodes/instrument_drivers/stanford_research/SR86x.py +++ b/src/qcodes/instrument_drivers/stanford_research/SR86x.py @@ -89,7 +89,7 @@ def get_raw(self) -> npt.NDArray: return self._capture_data -class SR86xBuffer(InstrumentChannel): +class SR86xBuffer(InstrumentChannel["SR86x"]): """ Buffer module for the SR86x drivers. @@ -529,7 +529,7 @@ def _get_raw_capture_data_block( f"({size_of_currently_captured_data}kB)" ) - values = self._parent.visa_handle.query_binary_values( + values = self.parent.visa_handle.query_binary_values( f"CAPTUREGET? {offset_in_kb}, {size_in_kb}", datatype="f", is_big_endian=False, @@ -621,7 +621,7 @@ def capture_samples(self, sample_count: int) -> dict[str, npt.NDArray]: return self.get_capture_data(sample_count) -class SR86xDataChannel(InstrumentChannel): +class SR86xDataChannel(InstrumentChannel["SR86x"]): """ Implements a data channel of SR86x lock-in amplifier. Parameters that are assigned to these channels get plotted on the display of the instrument. diff --git a/src/qcodes/instrument_drivers/tektronix/AWG5014.py b/src/qcodes/instrument_drivers/tektronix/AWG5014.py index 9f0d2cf4a828..9ca6c3d2d888 100644 --- a/src/qcodes/instrument_drivers/tektronix/AWG5014.py +++ b/src/qcodes/instrument_drivers/tektronix/AWG5014.py @@ -915,6 +915,9 @@ def get_sq_mode(self) -> str: def _pack_record( self, name: str, value: float | str | Sequence[Any] | npt.NDArray, dtype: str ) -> bytes: + def _pack_numpy_array(array: npt.NDArray) -> bytes: + return array.astype(" None: validator.validate(value) -class Tektronix70000AWGChannel(InstrumentChannel): +class Tektronix70000AWGChannel(InstrumentChannel["TektronixAWG70000Base"]): """ Class to hold a channel of the AWG. """ def __init__( self, - parent: Instrument, + parent: TektronixAWG70000Base, name: str, channel: int, **kwargs: Unpack[InstrumentBaseKWArgs], @@ -183,8 +182,8 @@ def __init__( self.channel = channel - num_channels = self.root_instrument.num_channels - self.model = self.root_instrument.model + num_channels = self.parent.num_channels + self.model = self.parent.model fg = "function generator" @@ -253,7 +252,7 @@ def __init__( label=f"Channel {channel} {fg} signal path", set_cmd=f"FGEN:CHANnel{channel}:PATH {{}}", get_cmd=f"FGEN:CHANnel{channel}:PATH?", - val_mapping=_fg_path_val_map[self.root_instrument.model], + val_mapping=_fg_path_val_map[self.parent.model], ) """Parameter fgen_signalpath""" @@ -430,7 +429,7 @@ def _set_fgfreq(self, channel: int, frequency: float) -> None: "Hz, minimum is 1 Hz" ) else: - self.root_instrument.write(f"FGEN:CHANnel{channel}:FREQuency {frequency}") + self.parent.write(f"FGEN:CHANnel{channel}:FREQuency {frequency}") def setWaveform(self, name: str) -> None: """ @@ -440,10 +439,10 @@ def setWaveform(self, name: str) -> None: name: The name of the waveform """ - if name not in self.root_instrument.waveformList: + if name not in self.parent.waveformList: raise ValueError("No such waveform in the waveform list") - self.root_instrument.write(f'SOURce{self.channel}:CASSet:WAVeform "{name}"') + self.parent.write(f'SOURce{self.channel}:CASSet:WAVeform "{name}"') def setSequenceTrack(self, seqname: str, tracknr: int) -> None: """ @@ -454,8 +453,7 @@ def setSequenceTrack(self, seqname: str, tracknr: int) -> None: tracknr: Which track to use (1 or 2) """ - - self.root_instrument.write( + self.parent.write( f'SOURCE{self.channel}:CASSet:SEQuence "{seqname}", {tracknr}' ) @@ -464,7 +462,7 @@ def clear_asset(self) -> None: Clear assigned assets on this channel """ - self.root_instrument.write(f"SOURce{self.channel}:CASSet:CLEAR") + self.parent.write(f"SOURce{self.channel}:CASSet:CLEAR") AWGChannel = Tektronix70000AWGChannel diff --git a/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py b/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py index f179ff5b71e5..729c28dcb52f 100644 --- a/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py +++ b/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py @@ -7,7 +7,7 @@ import textwrap import time from functools import partial -from typing import TYPE_CHECKING, Any, ClassVar, Generic +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self import numpy as np import numpy.typing as npt @@ -220,7 +220,7 @@ def __init__( """ -class TektronixDPOWaveform(InstrumentChannel): +class TektronixDPOWaveform(InstrumentChannel["TektronixDPOChannel"]): """ This submodule retrieves data from waveform sources, e.g. channels. @@ -234,7 +234,7 @@ class TektronixDPOWaveform(InstrumentChannel): def __init__( self, - parent: InstrumentBase, + parent: "TektronixDPOChannel", name: str, identifier: str, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -333,6 +333,12 @@ def __init__( ) """Parameter trace""" + @property + def root_instrument(self) -> "TektronixDPO7000xx": + root_instrument = super().root_instrument + assert isinstance(root_instrument, TektronixDPO7000xx) + return root_instrument + def _get_cmd(self, cmd_string: str) -> "Callable[[], str]": """ Parameters defined in this submodule require the correct @@ -351,7 +357,7 @@ def _get_trace_data(self) -> npt.NDArray: if not waveform.is_binary(): raw_data = self.root_instrument.visa_handle.query_ascii_values( - "CURVE?", container=np.array + "CURVE?", container=np.ndarray ) else: bytes_per_sample = waveform.bytes_per_sample() @@ -366,7 +372,7 @@ def _get_trace_data(self) -> npt.NDArray: "CURVE?", datatype=data_type, is_big_endian=is_big_endian, - container=np.array, + container=np.ndarray, ) return (raw_data - self.raw_data_offset()) * self.scale() + self.offset() @@ -418,7 +424,7 @@ def __init__( ) """Parameter is_big_endian""" - self.bytes_per_sample: Parameter = self.add_parameter( + self.bytes_per_sample: Parameter[int, Self] = self.add_parameter( "bytes_per_sample", get_cmd="WFMOutpre:BYT_Nr?", set_cmd="WFMOutpre:BYT_Nr {}", @@ -436,7 +442,7 @@ def __init__( """Parameter is_binary""" -class TektronixDPOChannel(InstrumentChannel): +class TektronixDPOChannel(InstrumentChannel[TektronixDPO7000xx]): """ The main channel module for the oscilloscope. The parameters defined here reflect the waveforms as they are displayed on @@ -445,7 +451,7 @@ class TektronixDPOChannel(InstrumentChannel): def __init__( self, - parent: Instrument | InstrumentChannel, + parent: TektronixDPO7000xx, name: str, channel_number: int, **kwargs: "Unpack[InstrumentBaseKWArgs]", @@ -522,15 +528,15 @@ def set_trace_length(self, value: int) -> None: value: The requested number of samples in the trace """ - if self.root_instrument.horizontal.record_length() < value: + if self.parent.horizontal.record_length() < value: raise ValueError( "Cannot set a trace length which is larger than " "the record length. Please switch to manual mode " "and adjust the record length first" ) - self.root_instrument.data.start_index(1) - self.root_instrument.data.stop_index(value) + self.parent.data.start_index(1) + self.parent.data.stop_index(value) def set_trace_time(self, value: float) -> None: """ @@ -538,7 +544,7 @@ def set_trace_time(self, value: float) -> None: value: The time over which a trace is desired. """ - sample_rate = self.root_instrument.horizontal.sample_rate() + sample_rate = self.parent.horizontal.sample_rate() required_sample_count = int(sample_rate * value) self.set_trace_length(required_sample_count) diff --git a/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py b/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py index 4eb217e6271f..dd82635aa478 100644 --- a/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py +++ b/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py @@ -38,7 +38,7 @@ class YokogawaGS200Exception(Exception): pass -class YokogawaGS200Monitor(InstrumentChannel): +class YokogawaGS200Monitor(InstrumentChannel["YokogawaGS200"]): """ Monitor part of the GS200. This is only enabled if it is installed in the GS200 (it is an optional extra). @@ -169,7 +169,7 @@ def state(self) -> int: def _get_measurement(self) -> float: if self._unit is None or self._range is None: raise YokogawaGS200Exception("Measurement module not initialized.") - if self._parent.auto_range.get() or (self._unit == "VOLT" and self._range < 1): + if self.parent.auto_range.get() or (self._unit == "VOLT" and self._range < 1): # Measurements will not work with autorange, or when # range is <1V. self._enabled = False @@ -207,7 +207,7 @@ def update_measurement_enabled( self.measure.unit = "V" -class YokogawaGS200Program(InstrumentChannel): +class YokogawaGS200Program(InstrumentChannel["YokogawaGS200"]): """ InstrumentModule that holds a Program for the YokoGawa GS200 diff --git a/src/qcodes/logger/instrument_logger.py b/src/qcodes/logger/instrument_logger.py index bb96129a6a56..2c7e15077303 100644 --- a/src/qcodes/logger/instrument_logger.py +++ b/src/qcodes/logger/instrument_logger.py @@ -107,8 +107,11 @@ class InstrumentFilter(logging.Filter): """ def __init__(self, instruments: InstrumentBase | Sequence[InstrumentBase]): + # avoid importing qcodes.instrument at module level to prevent circular imports + from qcodes.instrument import InstrumentBase # noqa: PLC0415 + super().__init__() - if not isinstance(instruments, collections.abc.Sequence): + if isinstance(instruments, InstrumentBase): instrument_seq: Sequence[str] = (instruments.full_name,) else: instrument_seq = [inst.full_name for inst in instruments] @@ -188,8 +191,14 @@ def filter_instrument( handlers = (myhandler,) elif not isinstance(handler, collections.abc.Sequence): handlers = (handler,) - else: + elif isinstance(handler, collections.abc.Sequence) and not isinstance( + handler, logging.Handler + ): handlers = handler + else: + raise TypeError( + f"handler must be a Handler or a Sequence of Handlers got {type(handler)}" + ) instrument_filter = InstrumentFilter(instrument) for h in handlers: diff --git a/src/qcodes/parameters/array_parameter.py b/src/qcodes/parameters/array_parameter.py index d069150e7e49..07744d8feb62 100644 --- a/src/qcodes/parameters/array_parameter.py +++ b/src/qcodes/parameters/array_parameter.py @@ -174,7 +174,7 @@ def __init__( if not is_sequence_of(shape, int): raise ValueError("shapes must be a tuple of ints, not " + repr(shape)) - self.shape = shape + self.shape: tuple[int, ...] = tuple(shape) # require one setpoint per dimension of shape sp_shape = (len(shape),) diff --git a/src/qcodes/parameters/parameter_base.py b/src/qcodes/parameters/parameter_base.py index 3344620fc2f4..bcd2aa8823ad 100644 --- a/src/qcodes/parameters/parameter_base.py +++ b/src/qcodes/parameters/parameter_base.py @@ -1413,8 +1413,8 @@ def update(self, other: Iterable[P]) -> None: def __iter__(self) -> Iterator[P]: return iter(self._dict) - def __contains__(self, item: object) -> bool: - return item in self._dict + def __contains__(self, x: object) -> bool: + return x in self._dict def __len__(self) -> int: return len(self._dict) diff --git a/src/qcodes/parameters/sequence_helpers.py b/src/qcodes/parameters/sequence_helpers.py index 6cf30f249299..3ce2344c47f8 100644 --- a/src/qcodes/parameters/sequence_helpers.py +++ b/src/qcodes/parameters/sequence_helpers.py @@ -59,6 +59,14 @@ def is_sequence_of( next_shape = cast("tuple[int, ...]", shape[1:]) + # ty currently cannot infer that depth is not None here + # when both branches of the if are taken into account + # the type is narrowed to be not None in both branches + if depth is None: + raise ValueError( + f"Could not infer depth. depth is {depth} and shape is {shape}" + ) + for item in obj: if depth > 1: if not is_sequence_of(item, types, depth=depth - 1, shape=next_shape): diff --git a/src/qcodes/utils/attribute_helpers.py b/src/qcodes/utils/attribute_helpers.py index b2b32a27fc3a..40bd652a5766 100644 --- a/src/qcodes/utils/attribute_helpers.py +++ b/src/qcodes/utils/attribute_helpers.py @@ -36,41 +36,42 @@ class DelegateAttributes: A list of attribute names (strings) to *not* delegate to any other dictionary or object. """ + if not TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: - if key in self.omit_delegate_attrs: - raise AttributeError( - f"'{self.__class__.__name__}' does not delegate attribute {key}" - ) - - for name in self.delegate_attr_dicts: - if key == name: - # needed to prevent infinite loops! + def __getattr__(self, key: str) -> Any: + if key in self.omit_delegate_attrs: raise AttributeError( - f"dict '{key}' has not been created in object '{self.__class__.__name__}'" + f"'{self.__class__.__name__}' does not delegate attribute {key}" ) - try: - d = getattr(self, name, None) - if d is not None: - return d[key] - except KeyError: - pass - for name in self.delegate_attr_objects: - if key == name: - raise AttributeError( - f"object '{key}' has not been created in object '{self.__class__.__name__}'" - ) - try: - obj = getattr(self, name, None) - if obj is not None: - return getattr(obj, key) - except AttributeError: - pass + for name in self.delegate_attr_dicts: + if key == name: + # needed to prevent infinite loops! + raise AttributeError( + f"dict '{key}' has not been created in object '{self.__class__.__name__}'" + ) + try: + d = getattr(self, name, None) + if d is not None: + return d[key] + except KeyError: + pass + + for name in self.delegate_attr_objects: + if key == name: + raise AttributeError( + f"object '{key}' has not been created in object '{self.__class__.__name__}'" + ) + try: + obj = getattr(self, name, None) + if obj is not None: + return getattr(obj, key) + except AttributeError: + pass - raise AttributeError( - f"'{self.__class__.__name__}' object and its delegates have no attribute '{key}'" - ) + raise AttributeError( + f"'{self.__class__.__name__}' object and its delegates have no attribute '{key}'" + ) def __dir__(self) -> list[str]: names = list(super().__dir__()) diff --git a/tests/common.py b/tests/common.py index 3d2ea62c3f11..6bd662f5308e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -92,7 +92,7 @@ def profile(func: Callable[P, T]) -> Callable[P, T]: """ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - profile_filename = func.__name__ + ".prof" + profile_filename = getattr(func, "__name__", "unknown_function") + ".prof" profiler = cProfile.Profile() result = profiler.runcall(func, *args, **kwargs) profiler.dump_stats(profile_filename) diff --git a/tests/dataset/measurement/test_measurement_context_manager.py b/tests/dataset/measurement/test_measurement_context_manager.py index 4377b81148af..7ca77b26ecbd 100644 --- a/tests/dataset/measurement/test_measurement_context_manager.py +++ b/tests/dataset/measurement/test_measurement_context_manager.py @@ -297,7 +297,7 @@ def test_unregister_parameter(DAC, DMM) -> None: not_parameters = [DAC, DMM, 0.0, 1] for notparam in not_parameters: with pytest.raises(ValueError): - meas.unregister_parameter(notparam) # pyright: ignore[reportArgumentType] + meas.unregister_parameter(notparam) # type: ignore # unregistering something not registered should silently "succeed" meas.unregister_parameter("totes_not_registered") diff --git a/tests/dataset/test_database_creation_and_upgrading.py b/tests/dataset/test_database_creation_and_upgrading.py index ea2e198ef3b9..8ee8bfab4544 100644 --- a/tests/dataset/test_database_creation_and_upgrading.py +++ b/tests/dataset/test_database_creation_and_upgrading.py @@ -52,6 +52,7 @@ from tests.common import error_caused_by, skip_if_no_fixtures from tests.dataset.conftest import temporarily_copied_DB +assert tests.dataset.__file__ is not None fixturepath = os.sep.join(tests.dataset.__file__.split(os.sep)[:-1]) fixturepath = os.path.join(fixturepath, "fixtures") diff --git a/tests/dataset/test_database_extract_runs.py b/tests/dataset/test_database_extract_runs.py index 7b0d727e8d49..af4598977c54 100644 --- a/tests/dataset/test_database_extract_runs.py +++ b/tests/dataset/test_database_extract_runs.py @@ -744,6 +744,7 @@ def test_old_versions_not_touched( _, new_v = get_db_version_and_newest_available_version(source_path) + assert tests.dataset.__file__ is not None fixturepath = os.sep.join(tests.dataset.__file__.split(os.sep)[:-1]) fixturepath = os.path.join( fixturepath, "fixtures", "db_files", "version2", "some_runs.db" diff --git a/tests/dataset/test_dataset_export.py b/tests/dataset/test_dataset_export.py index 5a377e28163f..0dbc7b7c8367 100644 --- a/tests/dataset/test_dataset_export.py +++ b/tests/dataset/test_dataset_export.py @@ -1579,13 +1579,13 @@ def test_multi_index_options_non_grid(mock_dataset_non_grid: DataSet) -> None: def test_multi_index_wrong_option(mock_dataset_non_grid: DataSet) -> None: with pytest.raises(ValueError, match="Invalid value for use_multi_index"): - mock_dataset_non_grid.to_xarray_dataset(use_multi_index=True) # pyright: ignore[reportArgumentType] + mock_dataset_non_grid.to_xarray_dataset(use_multi_index=True) # type: ignore with pytest.raises(ValueError, match="Invalid value for use_multi_index"): - mock_dataset_non_grid.to_xarray_dataset(use_multi_index=False) # pyright: ignore[reportArgumentType] + mock_dataset_non_grid.to_xarray_dataset(use_multi_index=False) # type: ignore with pytest.raises(ValueError, match="Invalid value for use_multi_index"): - mock_dataset_non_grid.to_xarray_dataset(use_multi_index="perhaps") # pyright: ignore[reportArgumentType] + mock_dataset_non_grid.to_xarray_dataset(use_multi_index="perhaps") # type: ignore def test_geneate_pandas_index() -> None: diff --git a/tests/dataset/test_fix_functions.py b/tests/dataset/test_fix_functions.py index 80155811b298..f7d95b3b5b93 100644 --- a/tests/dataset/test_fix_functions.py +++ b/tests/dataset/test_fix_functions.py @@ -18,6 +18,7 @@ from tests.common import skip_if_no_fixtures from tests.dataset.conftest import temporarily_copied_DB +assert tests.dataset.__file__ is not None fixturepath = os.sep.join(tests.dataset.__file__.split(os.sep)[:-1]) fixturepath = os.path.join(fixturepath, "fixtures") diff --git a/tests/dataset/test_measurement_extensions.py b/tests/dataset/test_measurement_extensions.py index 87a342b7323f..a641ed32a047 100644 --- a/tests/dataset/test_measurement_extensions.py +++ b/tests/dataset/test_measurement_extensions.py @@ -331,7 +331,7 @@ def test_dond_into_fails_with_together_sweeps( dond_into( datasaver, - TogetherSweep(sweep1, sweep2), # pyright: ignore [reportArgumentType] + TogetherSweep(sweep1, sweep2), # type: ignore meas1, ) _ = datasaver.dataset @@ -352,8 +352,8 @@ def test_dond_into_fails_with_groups(default_params, default_database_and_experi dond_into( datasaver, sweep1, - [meas1], # pyright: ignore [reportArgumentType] - [meas2], # pyright: ignore [reportArgumentType] + [meas1], # type: ignore + [meas2], # type: ignore ) _ = datasaver.dataset diff --git a/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1500.py b/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1500.py index e9fd04f2ed31..e7a2376e88b5 100644 --- a/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1500.py +++ b/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1500.py @@ -98,7 +98,7 @@ def test_submodule_access_by_channel(b1500: KeysightB1500) -> None: def test_enable_multiple_channels(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.enable_channels([1, 2, 3]) @@ -107,7 +107,7 @@ def test_enable_multiple_channels(b1500: KeysightB1500) -> None: def test_disable_multiple_channels(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.disable_channels([1, 2, 3]) @@ -116,7 +116,7 @@ def test_disable_multiple_channels(b1500: KeysightB1500) -> None: def test_use_nplc_for_high_speed_adc(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.use_nplc_for_high_speed_adc() mock_write.assert_called_once_with("AIT 0,2") @@ -129,7 +129,7 @@ def test_use_nplc_for_high_speed_adc(b1500: KeysightB1500) -> None: def test_use_nplc_for_high_resolution_adc(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.use_nplc_for_high_resolution_adc() mock_write.assert_called_once_with("AIT 1,2") @@ -142,7 +142,7 @@ def test_use_nplc_for_high_resolution_adc(b1500: KeysightB1500) -> None: def test_autozero_enabled(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write assert b1500.autozero_enabled() is False @@ -159,7 +159,7 @@ def test_autozero_enabled(b1500: KeysightB1500) -> None: def test_use_manual_mode_for_high_speed_adc(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.use_manual_mode_for_high_speed_adc() mock_write.assert_called_once_with("AIT 0,1") @@ -177,7 +177,7 @@ def test_use_manual_mode_for_high_speed_adc(b1500: KeysightB1500) -> None: def test_self_calibration_successful(b1500: KeysightB1500) -> None: mock_ask = MagicMock() - b1500.ask = mock_ask + b1500.ask: MagicMock = mock_ask mock_ask.return_value = "0" @@ -189,7 +189,7 @@ def test_self_calibration_successful(b1500: KeysightB1500) -> None: def test_self_calibration_failed(b1500: KeysightB1500) -> None: mock_ask = MagicMock() - b1500.ask = mock_ask + b1500.ask: MagicMock = mock_ask expected_response = CALResponse(1) + CALResponse(64) mock_ask.return_value = "65" @@ -207,7 +207,7 @@ def test_error_message(b1500: KeysightB1500) -> None: def test_clear_timer_count(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.clear_timer_count() mock_write.assert_called_once_with("TSR") @@ -220,7 +220,7 @@ def test_clear_timer_count(b1500: KeysightB1500) -> None: def test_set_measuremet_mode(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.set_measurement_mode(mode=constants.MM.Mode.SPOT, channels=[1, 2]) mock_write.assert_called_once_with("MM 1,1,2") @@ -228,7 +228,7 @@ def test_set_measuremet_mode(b1500: KeysightB1500) -> None: def test_get_measurement_mode(b1500: KeysightB1500) -> None: mock_ask = MagicMock() - b1500.ask = mock_ask + b1500.ask: MagicMock = mock_ask mock_ask.return_value = "MM 1,1,2" measurement_mode = b1500.get_measurement_mode() @@ -238,7 +238,7 @@ def test_get_measurement_mode(b1500: KeysightB1500) -> None: def test_get_response_format_and_mode(b1500: KeysightB1500) -> None: mock_ask = MagicMock() - b1500.ask = mock_ask + b1500.ask: MagicMock = mock_ask mock_ask.return_value = "FMT 1,1" measurement_mode = b1500.get_response_format_and_mode() @@ -248,7 +248,7 @@ def test_get_response_format_and_mode(b1500: KeysightB1500) -> None: def test_enable_smu_filters(b1500: KeysightB1500) -> None: mock_write = MagicMock() - b1500.write = mock_write + b1500.write: MagicMock = mock_write b1500.enable_smu_filters(True) mock_write.assert_called_once_with("FL 1") @@ -282,7 +282,7 @@ def test_error_message_is_called_after_setting_a_parameter( b1500: KeysightB1500, ) -> None: mock_ask = MagicMock() - b1500.ask = mock_ask + b1500.ask: MagicMock = mock_ask mock_ask.return_value = '+0,"No Error."' b1500.enable_smu_filters(True) diff --git a/tests/drivers/test_tektronix_AWG70000A.py b/tests/drivers/test_tektronix_AWG70000A.py index 004bb65ca9ca..18cfe3949f00 100644 --- a/tests/drivers/test_tektronix_AWG70000A.py +++ b/tests/drivers/test_tektronix_AWG70000A.py @@ -205,6 +205,7 @@ def test_seqxfile_from_fs(forged_sequence) -> None: # typing convenience make_seqx = TektronixAWG70000Base.make_SEQX_from_forged_sequence + assert auxfiles.__file__ is not None path_to_schema = auxfiles.__file__.replace("__init__.py", "awgSeqDataSets.xsd") with open(path_to_schema) as fid: diff --git a/tests/validators/test_literal.py b/tests/validators/test_literal.py index c549e8003afd..d1a88a521b3b 100644 --- a/tests/validators/test_literal.py +++ b/tests/validators/test_literal.py @@ -15,10 +15,10 @@ def test_literal_validator() -> None: a123_val.validate(1) with pytest.raises(ValueError, match="5 is not a member of "): - a123_val.validate(5, context="Outside range") # pyright: ignore[reportArgumentType] + a123_val.validate(5, context="Outside range") # type: ignore with pytest.raises(ValueError, match="some_str is not a member of "): - a123_val.validate("some_str", context="Wrong type") # pyright: ignore[reportArgumentType] + a123_val.validate("some_str", context="Wrong type") # type: ignore def test_literal_validator_repr() -> None: