Skip to content

Commit 907f666

Browse files
lifningahl
andauthored
Treat connection upgrade (HTTP 101) as normal success. (#1320)
* Treat connection upgrade (HTTP 101) as normal success. This fixes a bug wherein client code would be generated for a 2XX-range response on WebSocket endpoints, which in turn would unnecessarily trip an `assert!(response_types.len() <= 1)` in `extract_responses` when explicit status codes are provided in the OpenAPI document. (This also fixes a bug in such a scenario wherein *all* responses, including HTTP errors, would be set to Upgrade type) (As in dropshot#1548) * Per ahl PR feedback: Forego supporting the erroneous OpenAPI format (and manually update sample OpenAPI JSON for Nexus and Propolis accordingly) * Old, incorrect schema for WebSocket endpoints is explicitly an error * Further feedback: Support the previous format after all, and test that it continues to work * clean up --------- Co-authored-by: lif <> Co-authored-by: Adam H. Leventhal <ahl@oxide.computer>
1 parent c3776c7 commit 907f666

14 files changed

Lines changed: 747 additions & 63 deletions

progenitor-impl/src/method.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,13 @@ impl Generator {
443443
},
444444
))
445445
.map(|v: Result<(OperationResponseStatus, &Response)>| {
446-
let (status_code, response) = v?;
446+
let (mut status_code, response) = v?;
447+
448+
// A previous version of dropshot websockets failed to
449+
// properly report responses; clean this up here.
450+
if dropshot_websocket && status_code == OperationResponseStatus::Default {
451+
status_code = OperationResponseStatus::Code(101);
452+
}
447453

448454
// We categorize responses as "typed" based on the
449455
// "application/json" content type, "upgrade" if it's a
@@ -472,7 +478,7 @@ impl Generator {
472478
};
473479

474480
OperationResponseKind::Type(typ)
475-
} else if dropshot_websocket {
481+
} else if status_code == OperationResponseStatus::Code(101) {
476482
OperationResponseKind::Upgrade
477483
} else if response.content.first().is_some() {
478484
OperationResponseKind::Raw
@@ -484,6 +490,7 @@ impl Generator {
484490
if matches!(
485491
status_code,
486492
OperationResponseStatus::Default
493+
| OperationResponseStatus::Code(101)
487494
| OperationResponseStatus::Code(200..=299)
488495
| OperationResponseStatus::Range(2)
489496
) {
@@ -516,15 +523,6 @@ impl Generator {
516523
});
517524
}
518525

519-
// Must accept HTTP 101 Switching Protocols
520-
if dropshot_websocket {
521-
responses.push(OperationResponse {
522-
status_code: OperationResponseStatus::Code(101),
523-
typ: OperationResponseKind::Upgrade,
524-
description: None,
525-
})
526-
}
527-
528526
let dropshot_paginated = self.dropshot_pagination_data(operation, &params, &responses);
529527

530528
if dropshot_websocket && dropshot_paginated.is_some() {
@@ -533,6 +531,17 @@ impl Generator {
533531
operation_id
534532
)));
535533
}
534+
if dropshot_websocket
535+
&& responses
536+
.iter()
537+
.find(|r| r.status_code == OperationResponseStatus::Code(101))
538+
.is_none()
539+
{
540+
return Err(Error::InvalidExtension(format!(
541+
"websocket endpoint {:?} must include an explicit 101 response code",
542+
operation_id
543+
)));
544+
}
536545

537546
Ok(OperationMethod {
538547
operation_id: sanitize(operation_id, Case::Snake),

progenitor-impl/tests/output/src/nexus_builder.rs

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26332,6 +26332,26 @@ impl Client {
2633226332
builder::InstanceSerialConsoleStream::new(self)
2633326333
}
2633426334

26335+
///Connect to an instance's serial console
26336+
///
26337+
///Use `GET /v1/instances/{instance}/serial-console/stream` instead
26338+
///
26339+
///Sends a `GET` request to
26340+
/// `/organizations/{organization_name}/projects/{project_name}/instances/
26341+
/// {instance_name}/serial-console/stream_v2`
26342+
///
26343+
///```ignore
26344+
/// let response = client.instance_serial_console_stream_v2()
26345+
/// .organization_name(organization_name)
26346+
/// .project_name(project_name)
26347+
/// .instance_name(instance_name)
26348+
/// .send()
26349+
/// .await;
26350+
/// ```
26351+
pub fn instance_serial_console_stream_v2(&self) -> builder::InstanceSerialConsoleStreamV2<'_> {
26352+
builder::InstanceSerialConsoleStreamV2::new(self)
26353+
}
26354+
2633526355
///Boot an instance
2633626356
///
2633726357
///Use `POST /v1/instances/{instance}/start` instead
@@ -35228,9 +35248,7 @@ pub mod builder {
3522835248
///Sends a `GET` request to
3522935249
/// `/organizations/{organization_name}/projects/{project_name}/
3523035250
/// instances/{instance_name}/serial-console/stream`
35231-
pub async fn send(
35232-
self,
35233-
) -> Result<ResponseValue<reqwest::Upgraded>, Error<reqwest::Upgraded>> {
35251+
pub async fn send(self) -> Result<ResponseValue<reqwest::Upgraded>, Error<()>> {
3523435252
let Self {
3523535253
client,
3523635254
organization_name,
@@ -35277,7 +35295,118 @@ pub mod builder {
3527735295
let response = result?;
3527835296
match response.status().as_u16() {
3527935297
101u16 => ResponseValue::upgrade(response).await,
35280-
200..=299 => ResponseValue::upgrade(response).await,
35298+
_ => Err(Error::UnexpectedResponse(response)),
35299+
}
35300+
}
35301+
}
35302+
35303+
///Builder for [`Client::instance_serial_console_stream_v2`]
35304+
///
35305+
///[`Client::instance_serial_console_stream_v2`]: super::Client::instance_serial_console_stream_v2
35306+
#[derive(Debug, Clone)]
35307+
pub struct InstanceSerialConsoleStreamV2<'a> {
35308+
client: &'a super::Client,
35309+
organization_name: Result<types::Name, String>,
35310+
project_name: Result<types::Name, String>,
35311+
instance_name: Result<types::Name, String>,
35312+
}
35313+
35314+
impl<'a> InstanceSerialConsoleStreamV2<'a> {
35315+
pub fn new(client: &'a super::Client) -> Self {
35316+
Self {
35317+
client: client,
35318+
organization_name: Err("organization_name was not initialized".to_string()),
35319+
project_name: Err("project_name was not initialized".to_string()),
35320+
instance_name: Err("instance_name was not initialized".to_string()),
35321+
}
35322+
}
35323+
35324+
pub fn organization_name<V>(mut self, value: V) -> Self
35325+
where
35326+
V: std::convert::TryInto<types::Name>,
35327+
{
35328+
self.organization_name = value
35329+
.try_into()
35330+
.map_err(|_| "conversion to `Name` for organization_name failed".to_string());
35331+
self
35332+
}
35333+
35334+
pub fn project_name<V>(mut self, value: V) -> Self
35335+
where
35336+
V: std::convert::TryInto<types::Name>,
35337+
{
35338+
self.project_name = value
35339+
.try_into()
35340+
.map_err(|_| "conversion to `Name` for project_name failed".to_string());
35341+
self
35342+
}
35343+
35344+
pub fn instance_name<V>(mut self, value: V) -> Self
35345+
where
35346+
V: std::convert::TryInto<types::Name>,
35347+
{
35348+
self.instance_name = value
35349+
.try_into()
35350+
.map_err(|_| "conversion to `Name` for instance_name failed".to_string());
35351+
self
35352+
}
35353+
35354+
///Sends a `GET` request to
35355+
/// `/organizations/{organization_name}/projects/{project_name}/
35356+
/// instances/{instance_name}/serial-console/stream_v2`
35357+
pub async fn send(self) -> Result<ResponseValue<reqwest::Upgraded>, Error<types::Error>> {
35358+
let Self {
35359+
client,
35360+
organization_name,
35361+
project_name,
35362+
instance_name,
35363+
} = self;
35364+
let organization_name = organization_name.map_err(Error::InvalidRequest)?;
35365+
let project_name = project_name.map_err(Error::InvalidRequest)?;
35366+
let instance_name = instance_name.map_err(Error::InvalidRequest)?;
35367+
let url = format!(
35368+
"{}/organizations/{}/projects/{}/instances/{}/serial-console/stream_v2",
35369+
client.baseurl,
35370+
encode_path(&organization_name.to_string()),
35371+
encode_path(&project_name.to_string()),
35372+
encode_path(&instance_name.to_string()),
35373+
);
35374+
let mut header_map = ::reqwest::header::HeaderMap::with_capacity(1usize);
35375+
header_map.append(
35376+
::reqwest::header::HeaderName::from_static("api-version"),
35377+
::reqwest::header::HeaderValue::from_static(super::Client::api_version()),
35378+
);
35379+
#[allow(unused_mut)]
35380+
let mut request = client
35381+
.client
35382+
.get(url)
35383+
.headers(header_map)
35384+
.header(::reqwest::header::CONNECTION, "Upgrade")
35385+
.header(::reqwest::header::UPGRADE, "websocket")
35386+
.header(::reqwest::header::SEC_WEBSOCKET_VERSION, "13")
35387+
.header(
35388+
::reqwest::header::SEC_WEBSOCKET_KEY,
35389+
::base64::Engine::encode(
35390+
&::base64::engine::general_purpose::STANDARD,
35391+
::rand::random::<[u8; 16]>(),
35392+
),
35393+
)
35394+
.build()?;
35395+
let info = OperationInfo {
35396+
operation_id: "instance_serial_console_stream_v2",
35397+
};
35398+
client.pre(&mut request, &info).await?;
35399+
let result = client.exec(request, &info).await;
35400+
client.post(&result, &info).await?;
35401+
let response = result?;
35402+
match response.status().as_u16() {
35403+
101u16 => ResponseValue::upgrade(response).await,
35404+
400u16..=499u16 => Err(Error::ErrorResponse(
35405+
ResponseValue::from_response(response).await?,
35406+
)),
35407+
500u16..=599u16 => Err(Error::ErrorResponse(
35408+
ResponseValue::from_response(response).await?,
35409+
)),
3528135410
_ => Err(Error::UnexpectedResponse(response)),
3528235411
}
3528335412
}
@@ -47944,9 +48073,7 @@ pub mod builder {
4794448073

4794548074
///Sends a `GET` request to
4794648075
/// `/v1/instances/{instance}/serial-console/stream`
47947-
pub async fn send(
47948-
self,
47949-
) -> Result<ResponseValue<reqwest::Upgraded>, Error<reqwest::Upgraded>> {
48076+
pub async fn send(self) -> Result<ResponseValue<reqwest::Upgraded>, Error<types::Error>> {
4795048077
let Self {
4795148078
client,
4795248079
instance,
@@ -47996,7 +48123,12 @@ pub mod builder {
4799648123
let response = result?;
4799748124
match response.status().as_u16() {
4799848125
101u16 => ResponseValue::upgrade(response).await,
47999-
200..=299 => ResponseValue::upgrade(response).await,
48126+
400u16..=499u16 => Err(Error::ErrorResponse(
48127+
ResponseValue::from_response(response).await?,
48128+
)),
48129+
500u16..=599u16 => Err(Error::ErrorResponse(
48130+
ResponseValue::from_response(response).await?,
48131+
)),
4800048132
_ => Err(Error::UnexpectedResponse(response)),
4800148133
}
4800248134
}

0 commit comments

Comments
 (0)