Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 72 additions & 85 deletions src/generators/go/http/sigil_emit_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
use std::collections::{BTreeMap, BTreeSet, HashSet};

use crate::codegen::traits::file_writer::FileInfo;
use crate::generators::multipart::{MultipartValueEncoding, multipart_parts_for_request_body};
use crate::ir::types::{
IrObject, IrOperation, IrParameter, IrPrimitive, IrRequestBody, IrResponse, IrSchemaKind,
IrSpec, IrTypeExpr, ParameterLocation,
IrOperation, IrParameter, IrPrimitive, IrRequestBody, IrResponse, IrSpec, IrTypeExpr,
ParameterLocation,
};
use heck::{ToLowerCamelCase, ToPascalCase, ToSnakeCase};
use sigil_stitch::code_block::CodeBlock;
Expand Down Expand Up @@ -141,6 +142,7 @@ fn collect_body_imports(plans: &[OpPlan<'_>], module_path: &str) -> Vec<ImportSp
if let Some(parts) = &body.multipart_parts {
pkgs.insert("bytes".to_string());
pkgs.insert("mime/multipart".to_string());
pkgs.insert("net/textproto".to_string());
for part in parts {
collect_stringify_imports(&part.type_expr, &mut pkgs);
}
Expand Down Expand Up @@ -454,15 +456,10 @@ struct MultipartPart {
type_expr: IrTypeExpr,
is_binary: bool,
required: bool,
content_type: String,
value_encoding: MultipartValueEncoding,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum MultipartValueEncoding {
Text,
Json,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum BodyEncoding {
Json,
Expand Down Expand Up @@ -556,7 +553,7 @@ fn plan_body(
};
let var_name = unique_name("body", used_names);
let multipart_parts = if media_type_base(&media_type) == "multipart/form-data" {
multipart_parts_for(&t, ir)
multipart_parts_for(b, &media_type, ir)
} else {
None
};
Expand Down Expand Up @@ -902,22 +899,45 @@ fn emit_required_multipart_part(
part: &MultipartPart,
value_expr: &str,
) {
cb.add("{", ());
cb.add_line();
cb.add("partHeader := textproto.MIMEHeader{}", ());
cb.add_line();
let disposition = if part.is_binary {
format!(
"form-data; name={}; filename={}",
go_string_literal(&part.wire_name),
go_string_literal(&part.wire_name)
)
} else {
format!("form-data; name={}", go_string_literal(&part.wire_name))
};
cb.add(
&format!(
"partHeader.Set(\"Content-Disposition\", {})",
go_string_literal(&disposition)
),
(),
);
cb.add_line();
cb.add(
&format!(
"partHeader.Set(\"Content-Type\", {})",
go_string_literal(&part.content_type)
),
(),
);
cb.add_line();
cb.add("partWriter, err := writer.CreatePart(partHeader)", ());
cb.add_line();
cb.begin_control_flow("if err != nil", ());
cb.add(
"return nil, fmt.Errorf(\"create multipart part: %%w\", err)",
(),
);
cb.add_line();
cb.end_control_flow();
if part.is_binary {
cb.add(
&format!(
"partWriter, err := writer.CreateFormFile(\"{}\", \"{}\")",
part.wire_name, part.wire_name
),
(),
);
cb.add_line();
cb.begin_control_flow("if err != nil", ());
cb.add(
"return nil, fmt.Errorf(\"create multipart file: %%w\", err)",
(),
);
cb.add_line();
cb.end_control_flow();
cb.begin_control_flow(
&format!("if _, err := partWriter.Write({value_expr}); err != nil"),
(),
Expand All @@ -938,24 +958,23 @@ fn emit_required_multipart_part(
);
cb.add_line();
cb.end_control_flow();
cb.begin_control_flow(
&format!(
"if err := writer.WriteField(\"{}\", string(partValue)); err != nil",
part.wire_name
),
(),
);
cb.begin_control_flow("if _, err := partWriter.Write(partValue); err != nil", ());
cb.add(
"return nil, fmt.Errorf(\"write multipart field: %%w\", err)",
(),
);
cb.add_line();
cb.end_control_flow();
} else if part.value_encoding == MultipartValueEncoding::Unsupported {
cb.add(
"return nil, fmt.Errorf(\"unsupported multipart part content type\")",
(),
);
cb.add_line();
} else {
cb.begin_control_flow(
&format!(
"if err := writer.WriteField(\"{}\", {}); err != nil",
part.wire_name,
"if _, err := io.WriteString(partWriter, {}); err != nil",
render_value_as_string(value_expr, &part.type_expr)
),
(),
Expand All @@ -967,6 +986,8 @@ fn emit_required_multipart_part(
cb.add_line();
cb.end_control_flow();
}
cb.add("}", ());
cb.add_line();
}

fn emit_decode_into(field: &str, go_ty: &str, decoding: ResponseDecoding) -> CodeBlock {
Expand Down Expand Up @@ -1174,63 +1195,29 @@ fn is_xml_media_type(media_type: &str) -> bool {
base == "application/xml" || base == "text/xml" || base.ends_with("+xml")
}

fn multipart_parts_for(t: &IrTypeExpr, ir: &IrSpec) -> Option<Vec<MultipartPart>> {
// TODO: Honor OpenAPI multipart encoding metadata once it is represented in the IR.
resolve_object(t, ir).map(|obj| {
obj.properties
.iter()
.map(|(wire_name, prop)| MultipartPart {
wire_name: wire_name.clone(),
field_name: go_field_name(wire_name),
type_expr: prop.type_expr.clone(),
is_binary: is_binary_type(&prop.type_expr, ir),
required: prop.required && !prop.nullable,
value_encoding: multipart_value_encoding(&prop.type_expr, ir),
fn multipart_parts_for(
body: &IrRequestBody,
media_type: &str,
ir: &IrSpec,
) -> Option<Vec<MultipartPart>> {
multipart_parts_for_request_body(body, media_type, ir).map(|parts| {
parts
.into_iter()
.map(|part| MultipartPart {
field_name: go_field_name(&part.wire_name),
wire_name: part.wire_name,
type_expr: part.type_expr,
is_binary: part.is_binary,
required: part.required,
content_type: part.content_type,
value_encoding: part.value_encoding,
})
.collect()
})
}

fn resolve_object<'a>(expr: &IrTypeExpr, ir: &'a IrSpec) -> Option<&'a IrObject> {
match expr {
IrTypeExpr::Named(name) => match ir.schemas.get(name).map(|schema| &schema.kind) {
Some(IrSchemaKind::Object(obj)) => Some(obj),
Some(IrSchemaKind::Alias(inner)) => resolve_object(inner, ir),
_ => None,
},
IrTypeExpr::Nullable(inner) => resolve_object(inner, ir),
_ => None,
}
}

fn is_binary_type(expr: &IrTypeExpr, ir: &IrSpec) -> bool {
match expr {
IrTypeExpr::Primitive(IrPrimitive::Binary) => true,
IrTypeExpr::Nullable(inner) => is_binary_type(inner, ir),
IrTypeExpr::Named(name) => ir.schemas.get(name).is_some_and(|schema| {
matches!(&schema.kind, IrSchemaKind::Alias(inner) if is_binary_type(inner, ir))
}),
_ => false,
}
}

fn multipart_value_encoding(expr: &IrTypeExpr, ir: &IrSpec) -> MultipartValueEncoding {
if is_multipart_text_type(expr, ir) {
MultipartValueEncoding::Text
} else {
MultipartValueEncoding::Json
}
}

fn is_multipart_text_type(expr: &IrTypeExpr, ir: &IrSpec) -> bool {
match expr {
IrTypeExpr::Primitive(_) | IrTypeExpr::StringLiteral(_) | IrTypeExpr::StringEnum(_) => true,
IrTypeExpr::Nullable(inner) => is_multipart_text_type(inner, ir),
IrTypeExpr::Named(name) => ir.schemas.get(name).is_some_and(|schema| {
matches!(&schema.kind, IrSchemaKind::Alias(inner) if is_multipart_text_type(inner, ir))
}),
_ => false,
}
fn go_string_literal(value: &str) -> String {
format!("{value:?}")
}

// ---------------------------------------------------------------------------
Expand Down
Loading
Loading