diff --git a/Cargo.lock b/Cargo.lock index bebbf9a9..fb8ac725 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,6 +383,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee2626afccd7561a06cf1367e2950c4718ea04565e20fb5029b6c7d8ad09abcf" +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + [[package]] name = "encode_unicode" version = "0.3.6" @@ -850,6 +856,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.1" @@ -958,6 +973,16 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mio" version = "0.8.4" @@ -1191,7 +1216,7 @@ dependencies = [ "progenitor-macro", "project-root", "rand", - "regress", + "regress 0.5.0", "reqwest", "schemars", "serde", @@ -1224,6 +1249,7 @@ dependencies = [ "http", "hyper", "indexmap", + "itertools", "openapiv3", "proc-macro2", "quote", @@ -1346,6 +1372,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "regress" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a9ecfa0cb04d0b04dddb99b8ccf4f66bc8dfd23df694b398570bd8ae3a50fb" +dependencies = [ + "hashbrown 0.13.2", + "memchr", +] + [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1375,6 +1411,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", @@ -2102,8 +2139,8 @@ checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" [[package]] name = "typify" -version = "0.0.12-dev" -source = "git+https://github.com/oxidecomputer/typify#6d77f63b3dc5312cd36549507f941cb5d783600e" +version = "0.0.12" +source = "git+https://github.com/drahnr/typify?branch=main#28db444b60335af3db6e72e9be7f121148b8fdef" dependencies = [ "typify-impl", "typify-macro", @@ -2111,14 +2148,14 @@ dependencies = [ [[package]] name = "typify-impl" -version = "0.0.12-dev" -source = "git+https://github.com/oxidecomputer/typify#6d77f63b3dc5312cd36549507f941cb5d783600e" +version = "0.0.12" +source = "git+https://github.com/drahnr/typify?branch=main#28db444b60335af3db6e72e9be7f121148b8fdef" dependencies = [ "heck", "log", "proc-macro2", "quote", - "regress", + "regress 0.6.0", "schemars", "serde_json", "syn 2.0.8", @@ -2128,8 +2165,8 @@ dependencies = [ [[package]] name = "typify-macro" -version = "0.0.12-dev" -source = "git+https://github.com/oxidecomputer/typify#6d77f63b3dc5312cd36549507f941cb5d783600e" +version = "0.0.12" +source = "git+https://github.com/drahnr/typify?branch=main#28db444b60335af3db6e72e9be7f121148b8fdef" dependencies = [ "proc-macro2", "quote", @@ -2147,6 +2184,15 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.7" diff --git a/Cargo.toml b/Cargo.toml index 5539eed9..7254a801 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,9 @@ members = [ #[patch."https://github.com/oxidecomputer/dropshot"] #dropshot = { path = "../dropshot/dropshot" } -#[patch."https://github.com/oxidecomputer/typify"] -#typify = { path = "../typify/typify" } +[patch."https://github.com/oxidecomputer/typify"] +typify = { git = "https://github.com/drahnr/typify", branch = "main" } +# typify = { path = "../typify/typify" } #[patch.crates-io] #serde_tokenstream = { path = "../serde_tokenstream" } diff --git a/README.md b/README.md index 4c4295dc..b8776ae0 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ You'll need to add the following to `Cargo.toml`: [dependencies] +futures = "0.3" +progenitor = { git = "https://github.com/oxidecomputer/progenitor" } -+reqwest = { version = "0.11", features = ["json", "stream"] } ++reqwest = { version = "0.11", features = ["json", "stream", "multipart"] } +serde = { version = "1.0", features = ["derive"] } ``` @@ -123,7 +123,7 @@ You'll need to add the following to `Cargo.toml`: [dependencies] +futures = "0.3" +progenitor-client = { git = "https://github.com/oxidecomputer/progenitor" } -+reqwest = { version = "0.11", features = ["json", "stream"] } ++reqwest = { version = "0.11", features = ["json", "stream", "multipart"] } +serde = { version = "1.0", features = ["derive"] } [build-dependencies] @@ -180,7 +180,7 @@ bytes = "1.3.0" chrono = { version = "0.4.23", default-features=false, features = ["serde"] } futures-core = "0.3.25" percent-encoding = "2.2.0" -reqwest = { version = "0.11.13", default-features=false, features = ["json", "stream"] } +reqwest = { version = "0.11.13", default-features=false, features = ["json", "stream", "multipart"] } serde = { version = "1.0.152", features = ["derive"] } serde_urlencoded = "0.7.1" diff --git a/example-build/Cargo.toml b/example-build/Cargo.toml index d1eaac93..044368f0 100644 --- a/example-build/Cargo.toml +++ b/example-build/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] chrono = { version = "0.4", features = ["serde"] } progenitor-client = { path = "../progenitor-client" } -reqwest = { version = "0.11.16", features = ["json", "stream"] } +reqwest = { version = "0.11.16", features = ["json", "stream", "multipart"] } base64 = "0.21" rand = "0.8" serde = { version = "1.0", features = ["derive"] } diff --git a/example-macro/Cargo.toml b/example-macro/Cargo.toml index 6648fc2f..1b862e1e 100644 --- a/example-macro/Cargo.toml +++ b/example-macro/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] chrono = { version = "0.4", features = ["serde"] } progenitor = { path = "../progenitor" } -reqwest = { version = "0.11.16", features = ["json", "stream"] } +reqwest = { version = "0.11.16", features = ["json", "stream", "multipart"] } schemars = { version = "0.8.12", features = ["uuid1"] } serde = { version = "1.0", features = ["derive"] } uuid = { version = "1.3", features = ["serde", "v4"] } diff --git a/progenitor-client/Cargo.toml b/progenitor-client/Cargo.toml index c06d614d..50ec9976 100644 --- a/progenitor-client/Cargo.toml +++ b/progenitor-client/Cargo.toml @@ -10,7 +10,7 @@ description = "An OpenAPI client generator - client support" bytes = "1.4.0" futures-core = "0.3.27" percent-encoding = "2.2" -reqwest = { version = "0.11.16", default-features = false, features = ["json", "stream"] } +reqwest = { version = "0.11.16", default-features = false, features = ["json", "stream", "multipart"] } serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7.1" diff --git a/progenitor-client/src/progenitor_client.rs b/progenitor-client/src/progenitor_client.rs index f2951f9d..d1ed8d18 100644 --- a/progenitor-client/src/progenitor_client.rs +++ b/progenitor-client/src/progenitor_client.rs @@ -4,11 +4,14 @@ //! Support code for generated clients. -use std::ops::{Deref, DerefMut}; +use std::{ + borrow::Cow, + ops::{Deref, DerefMut}, fs, +}; use bytes::Bytes; use futures_core::Stream; -use reqwest::RequestBuilder; +use reqwest::{RequestBuilder, multipart}; use serde::{de::DeserializeOwned, Serialize}; type InnerByteStream = @@ -63,10 +66,13 @@ impl ResponseValue { ) -> Result> { let status = response.status(); let headers = response.headers().clone(); - let inner = response - .json() - .await - .map_err(Error::InvalidResponsePayload)?; + let response = response.text().await?; + fs::write("responsefcuk.json", &response).unwrap(); + let inner: T = serde_json::from_str(&response).unwrap(); + // let inner = response + // .json() + // .await + // .map_err(|e| Error::InvalidResponsePayload(reqwest::Error::from(e)))?; Ok(Self { inner, @@ -386,11 +392,22 @@ pub fn encode_path(pc: &str) -> String { } #[doc(hidden)] -pub trait RequestBuilderExt { +pub trait RequestBuilderExt +where + Self: Sized, +{ fn form_urlencoded( self, body: &T, ) -> Result>; + + fn form_from_raw< + S: AsRef, + I: Sized + IntoIterator, + >( + self, + iter: I, + ) -> Result>; } impl RequestBuilderExt for RequestBuilder { @@ -405,8 +422,33 @@ impl RequestBuilderExt for RequestBuilder { "application/x-www-form-urlencoded", ), ) - .body(serde_urlencoded::to_string(body).map_err(|_| { - Error::InvalidRequest("failed to serialize body".to_string()) + .body(serde_urlencoded::to_string(body).map_err(|e| { + Error::InvalidRequest(format!( + "failed to serialize body: {e:?}" + )) })?)) } + + fn form_from_raw< + S: AsRef, + I: Sized + IntoIterator, + >( + self, + mut iter: I, + ) -> Result> { + use reqwest::multipart::{Form, Part}; + + let mut form = Form::new(); + for (name, part) in iter { + form = form.part( + name.as_ref().to_owned(), part, + ); + } + + dbg!(&form); + + Ok(self + .multipart(form) + ) + } } diff --git a/progenitor-impl/Cargo.toml b/progenitor-impl/Cargo.toml index d6fabeb6..94e9c24e 100644 --- a/progenitor-impl/Cargo.toml +++ b/progenitor-impl/Cargo.toml @@ -11,6 +11,7 @@ readme = "../README.md" heck = "0.4.1" getopts = "0.2" indexmap = "1.9" +itertools = "0.10" openapiv3 = "1.0.0" proc-macro2 = "1.0" quote = "1.0" diff --git a/progenitor-impl/src/cli.rs b/progenitor-impl/src/cli.rs index f7d265d4..66d185e8 100644 --- a/progenitor-impl/src/cli.rs +++ b/progenitor-impl/src/cli.rs @@ -7,6 +7,7 @@ use quote::{format_ident, quote}; use typify::{TypeSpaceImpl, TypeStructPropInfo}; use crate::{ + Security, method::{ OperationParameterKind, OperationParameterType, OperationResponseStatus, }, @@ -49,6 +50,8 @@ impl Generator { ) -> Result { validate_openapi(spec)?; + let security = Security::from(spec); + // Convert our components dictionary to schemars let schemas = spec.components.iter().flat_map(|components| { components.schemas.iter().map(|(name, ref_or_schema)| { @@ -74,6 +77,7 @@ impl Generator { &spec.components, path, method, + &security, path_parameters, ) }) diff --git a/progenitor-impl/src/lib.rs b/progenitor-impl/src/lib.rs index 331c3b1a..1b9e39d6 100644 --- a/progenitor-impl/src/lib.rs +++ b/progenitor-impl/src/lib.rs @@ -2,12 +2,16 @@ use std::collections::{HashMap, HashSet}; -use openapiv3::OpenAPI; +use indexmap::IndexMap; +use openapiv3::{OpenAPI, SecurityRequirement, SecurityScheme}; use proc_macro2::TokenStream; use quote::quote; use serde::Deserialize; +use template::PathTemplate; use thiserror::Error; +use typify::{TypeDetails, TypeId}; use typify::{TypeSpace, TypeSpaceSettings}; +use util::ReferenceOrExt; use crate::to_schema::ToSchema; @@ -40,6 +44,7 @@ pub type Result = std::result::Result; pub struct Generator { type_space: TypeSpace, + forms: HashSet, settings: GenerationSettings, uses_futures: bool, uses_websockets: bool, @@ -163,6 +168,7 @@ impl Default for Generator { type_space: TypeSpace::new( TypeSpaceSettings::default().with_type_mod("types"), ), + forms: Default::default(), settings: Default::default(), uses_futures: Default::default(), uses_websockets: Default::default(), @@ -170,6 +176,53 @@ impl Default for Generator { } } +#[derive(Debug, Clone)] +pub(crate) struct Security { + pub(crate) per_path: IndexMap, + /// The global available security requirements, the defaults, named, which must be ref'd by overrides + pub(crate) global: Vec, + /// Declares the scheme and which header to use, each one referenced above must exist in the below + pub(crate) schemes: IndexMap, +} + +// usuful to derive the global fallback, if any +impl Security { + pub(crate) fn resolve_for_path(&self, path: &PathTemplate) -> Option { + let path = path.to_string(); + let requirements = self.per_path.get(&path).cloned().or_else(|| { self.global.first().cloned() })?; + let mut schemes = + requirements + .iter() + .map(|(name,requirements)| { + self.schemes + .get(name) + .expect("Contains that name, otherwise spec is buggy. qed") + }); + // TODO let's start with exactly one or zero schemes + assert!(schemes.len() <= 1); + schemes.next().cloned() + } + + pub(crate) fn from(spec: &OpenAPI) -> Self { + Self { + per_path: Default::default(), // TODO + global: spec.security.clone().unwrap_or_default(), + schemes: spec.components.as_ref().map(|c| { + IndexMap::from_iter( + c.security_schemes + .iter() + .map(|(key, reference_or_sec_scheme)| { + ( + key.to_owned(), + ReferenceOrExt::item(reference_or_sec_scheme, &spec.components).cloned().expect("Spec was checked for validity, so this must work. qed") + ) + }) + ) + }).unwrap_or_default(), + } + } +} + impl Generator { pub fn new(settings: &GenerationSettings) -> Self { let mut type_settings = TypeSpaceSettings::default(); @@ -204,14 +257,18 @@ impl Generator { Self { type_space: TypeSpace::new(&type_settings), settings: settings.clone(), + forms: Default::default(), uses_futures: false, uses_websockets: false, } } + /// Generate the actual rust implementation from the specification pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result { validate_openapi(spec)?; + let security = Security::from(&spec); + // Convert our components dictionary to schemars let schemas = spec.components.iter().flat_map(|components| { components.schemas.iter().map(|(name, ref_or_schema)| { @@ -237,6 +294,7 @@ impl Generator { &spec.components, path, method, + &security, path_parameters, ) }) @@ -262,6 +320,51 @@ impl Generator { let types = self.type_space.to_stream(); + let extra_impl = TokenStream::from_iter( + self.forms + .iter() + .map(|type_id| { + let typ = self.get_type_space().get_type(type_id).unwrap(); + let form_name = typ.name(); + let td = typ.details(); + let TypeDetails::Struct(tstru) = td else { unreachable!() }; + let properties = indexmap::IndexMap::<&'_ str, _>::from_iter( + tstru + .properties() + .filter_map(|(prop_name, prop_id)| { + self.get_type_space() + .get_type(&prop_id).ok() + .map(|prop_typ| (prop_name, prop_typ)) + }) + ); + let properties = syn::punctuated::Punctuated::<_, syn::Token![,]>::from_iter( + properties + .into_iter() + .map(|(prop_name, prop_ty)| { + let ident = quote::format_ident!("{}", prop_name); + quote!{ (#prop_name, &self. #ident) } + })); + + let form_name = quote::format_ident!("{}",typ.name()); + + quote! { + impl #form_name { + pub fn as_form<'f>(&'f self) -> impl std::iter::Iterator + 'f { + [#properties] + .into_iter() + .filter_map(|(name, val)|{ + val.as_ref().map(|val| (name, val)) + }) + .map(|(name, val)| { + let part = reqwest::multipart::Part::stream(val.to_vec()) + .file_name("sortme.pdf".to_owned()) // required for sevdesk, for "validation" + .mime_str("application/pdf").unwrap(); + (name, part) + }) + } + } + } + })); // Generate an implementation of a `Self::as_inner` method, if an inner // type is defined. let maybe_inner = self.settings.inner_type.as_ref().map(|inner| { @@ -290,20 +393,20 @@ impl Generator { }); let client_docstring = { - let mut s = format!("Client for {}", spec.info.title); + let mut doc = format!("Client for {}", spec.info.title); - if let Some(ss) = &spec.info.description { - s.push_str("\n\n"); - s.push_str(ss); + if let Some(desc) = &spec.info.description { + doc.push_str("\n\n"); + doc.push_str(desc); } - if let Some(ss) = &spec.info.terms_of_service { - s.push_str("\n\n"); - s.push_str(ss); + if let Some(tos) = &spec.info.terms_of_service { + doc.push_str("\n\n"); + doc.push_str(tos); } - s.push_str(&format!("\n\nVersion: {}", &spec.info.version)); + doc.push_str(&format!("\n\nVersion: {}", &spec.info.version)); - s + doc }; let version_str = &spec.info.version; @@ -325,6 +428,8 @@ impl Generator { use std::convert::TryFrom; #types + + #extra_impl } #[derive(Clone, Debug)] @@ -409,7 +514,7 @@ impl Generator { .collect::>>()?; let out = quote! { impl Client { - #(#methods)* + #( #methods )* } pub mod prelude { diff --git a/progenitor-impl/src/method.rs b/progenitor-impl/src/method.rs index 9ab4dc91..d3c76366 100644 --- a/progenitor-impl/src/method.rs +++ b/progenitor-impl/src/method.rs @@ -2,11 +2,12 @@ use std::{ cmp::Ordering, - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, str::FromStr, }; -use openapiv3::{Components, Parameter, ReferenceOr, Response, StatusCode}; +use indexmap::IndexSet; +use openapiv3::{Components, Parameter, ReferenceOr, Response, StatusCode, OpenAPI, APIKeyLocation, SecurityScheme}; use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use typify::{TypeId, TypeSpace}; @@ -14,16 +15,18 @@ use typify::{TypeId, TypeSpace}; use crate::{ template::PathTemplate, util::{items, parameter_map, sanitize, Case}, - Error, Generator, Result, TagStyle, + Error, Generator, Result, TagStyle, Security, }; use crate::{to_schema::ToSchema, util::ReferenceOrExt}; + /// The intermediate representation of an operation that will become a method. pub(crate) struct OperationMethod { pub operation_id: String, pub tags: Vec, method: HttpMethod, path: PathTemplate, + pub security: Security, pub summary: Option, pub description: Option, pub params: Vec, @@ -101,9 +104,10 @@ pub struct OperationParameter { pub kind: OperationParameterKind, } -#[derive(Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq)] pub enum OperationParameterType { Type(TypeId), + Form(TypeId), RawBody, } @@ -120,6 +124,7 @@ pub enum BodyContentType { OctetStream, Json, FormUrlencoded, + FormData, } impl FromStr for BodyContentType { @@ -130,6 +135,7 @@ impl FromStr for BodyContentType { "application/octet-stream" => Ok(Self::OctetStream), "application/json" => Ok(Self::Json), "application/x-www-form-urlencoded" => Ok(Self::FormUrlencoded), + "form-data" | "multipart/form-data" => Ok(Self::FormData), _ => Err(Error::UnexpectedFormat(format!( "unexpected content type: {}", s @@ -138,10 +144,24 @@ impl FromStr for BodyContentType { } } +use std::fmt; + +impl fmt::Display for BodyContentType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::OctetStream => "application/octet-stream", + Self::Json => "application/json", + Self::FormUrlencoded => "application/x-www-form-urlencoded", + Self::FormData => "multipart/form-data", + }) + } +} + #[derive(Debug)] pub(crate) struct OperationResponse { status_code: OperationResponseStatus, typ: OperationResponseType, + format: Option, // TODO this isn't currently used because dropshot doesn't give us a // particularly useful message here. #[allow(dead_code)] @@ -223,6 +243,23 @@ impl PartialOrd for OperationResponseStatus { } } +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub(crate) enum OperationResponseFormat { + JSON, + XML, + FormUrlencoded, +} + +impl std::fmt::Display for OperationResponseFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::JSON => "application/json;charset=UTF-8", + Self::XML => "application/xml;charset=UTF-8", + Self::FormUrlencoded => "application/x-www-form-urlencoded", + }) + } +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub(crate) enum OperationResponseType { Type(TypeId), @@ -258,6 +295,7 @@ impl Generator { components: &Option, path: &str, method: &str, + security: &Security, path_parameters: &[ReferenceOr], ) -> Result { let operation_id = operation.operation_id.as_ref().unwrap(); @@ -328,7 +366,7 @@ impl Generator { let typ = self .type_space - .add_type_with_name(&schema, Some(name))?; + .add_type_with_name(&schema, Some(dbg!(name)))?; query.push(( parameter_data.name.clone(), @@ -454,7 +492,7 @@ impl Generator { // enum; the generated client method would check for the // content type of the response just as it currently examines // the status code. - let typ = if let Some(mt) = + let (typ, format ) = if let Some(mt) = response.content.get("application/json") { assert!(mt.encoding.is_empty()); @@ -474,13 +512,13 @@ impl Generator { todo!("media type encoding, no schema: {:#?}", mt); }; - OperationResponseType::Type(typ) + (OperationResponseType::Type(typ), Some(OperationResponseFormat::JSON)) } else if dropshot_websocket { - OperationResponseType::Upgrade + (OperationResponseType::Upgrade, None) } else if response.content.first().is_some() { - OperationResponseType::Raw + (OperationResponseType::Raw, None) } else { - OperationResponseType::None + (OperationResponseType::None, None) }; // See if there's a status code that covers success cases. @@ -502,6 +540,7 @@ impl Generator { Ok(OperationResponse { status_code, typ, + format, description, }) }) @@ -515,6 +554,7 @@ impl Generator { responses.push(OperationResponse { status_code: OperationResponseStatus::Range(2), typ: OperationResponseType::Raw, + format: None, description: None, }); } @@ -524,6 +564,7 @@ impl Generator { responses.push(OperationResponse { status_code: OperationResponseStatus::Code(101), typ: OperationResponseType::Upgrade, + format: None, description: None, }) } @@ -548,8 +589,9 @@ impl Generator { .description .clone() .filter(|s| !s.is_empty()), - params, responses, + params, + security: security.clone(), // TODO parse per request requirements instead of assuming the global default dropshot_paginated, dropshot_websocket, }) @@ -567,18 +609,19 @@ impl Generator { .iter() .map(|param| { let name = format_ident!("{}", param.name); - let typ = match ¶m.typ { - OperationParameterType::Type(type_id) => self - .type_space - .get_type(type_id) - .unwrap() - .parameter_ident_with_lifetime("a"), + match ¶m.typ { + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id) => { + let typ = self + .type_space + .get_type(type_id) + .expect("TypeIDs are _never_ deleted. qed") + .parameter_ident_with_lifetime("a"); + quote! { #name: #typ} + } OperationParameterType::RawBody => { - quote! { B } + quote! { #name: B } } - }; - quote! { - #name: #typ } }) .collect::>(); @@ -759,7 +802,7 @@ impl Generator { client: TokenStream, ) -> Result { // Generate code for query parameters. - let query_items = method + let query_items = Vec::from_iter(method .params .iter() .filter_map(|param| match ¶m.kind { @@ -781,8 +824,7 @@ impl Generator { Some(res) } _ => None, - }) - .collect::>(); + })); let (query_build, query_use) = if query_items.is_empty() { (quote! {}, quote! {}) @@ -799,7 +841,7 @@ impl Generator { (query_build, query_use) }; - let headers = method + let mut headers = method .params .iter() .filter_map(|param| match ¶m.kind { @@ -822,15 +864,47 @@ impl Generator { _ => None, }) .collect::>(); - + + if let Some(sec_scheme) = method.security.resolve_for_path(&method.path) { + match sec_scheme { + SecurityScheme::APIKey { location: APIKeyLocation::Header, name, .. } => { + let hn = name; + headers.push(quote!{ + { + let value = self.inner.header_api_key(#hn); + header_map.append(#hn, HeaderValue::try_from(value)?); + } + }) + } + SecurityScheme::APIKey { location, name, .. } => { + } + SecurityScheme::HTTP { scheme, bearer_format, description } => { + todo!("Craft http header") + } + _ => eprintln!("Only header APIKeys are supported right now, others have to be impl'd manually"), + } + }; + + use itertools::Itertools; + + // Add "Accept" headers + let accepts = method.responses.iter().filter_map(|response| response.format.as_ref()).unique().map(|format| { + let format = format.to_string(); + quote! { + header_map.append("Accept", HeaderValue::from_static(#format)); + } + }); + headers.extend(accepts); + let (headers_build, headers_use) = if headers.is_empty() { (quote! {}, quote! {}) } else { let size = headers.len(); let headers_build = quote! { let mut header_map = HeaderMap::with_capacity(#size); - #(#headers)* + #( #headers )* }; + let headers_use = quote! { .headers(header_map) }; @@ -857,16 +931,15 @@ impl Generator { // Generate the path rename map; then use it to generate code for // assigning the path parameters to the `url` variable. - let url_renames = method - .params - .iter() - .filter_map(|param| match ¶m.kind { - OperationParameterKind::Path => { - Some((¶m.api_name, ¶m.name)) + let url_renames = + HashMap::from_iter(method.params.iter().filter_map(|param| { + match ¶m.kind { + OperationParameterKind::Path => { + Some((¶m.api_name, ¶m.name)) + } + _ => None, } - _ => None, - }) - .collect(); + })); let url_path = method.path.compile(url_renames, client.clone()); @@ -902,6 +975,17 @@ impl Generator { // returns an error in the case of a serialization failure. .form_urlencoded(&body)? }), + ( + OperationParameterKind::Body( + BodyContentType::FormData + ), + OperationParameterType::Form(_), + ) => { + Some(quote! { + // form data header is set automatically by our call to reqwest's `fn multipart(..)`. + // This uses progenitor_client::RequestBuilderExt which sets up a simple form data based on bytes + .form_from_raw(body.as_form())? + })}, (OperationParameterKind::Body(_), _) => { unreachable!("invalid body kind/type combination") } @@ -1042,7 +1126,7 @@ impl Generator { let request = #client.client . #method_func (url) - #(#body_func)* + #( #body_func )* #query_use #headers_use #websock_hdrs @@ -1148,6 +1232,7 @@ impl Generator { .next() // TODO should this be OperationResponseType::Raw? .unwrap_or(OperationResponseType::None); + (response_items, response_type) } @@ -1261,7 +1346,7 @@ impl Generator { .ok()? .details() { - typify::TypeDetails::Array(item) => { + typify::TypeDetails::Array(item, ..) => { Some(DropshotPagination { item }) } _ => None, @@ -1365,7 +1450,8 @@ impl Generator { .params .iter() .map(|param| match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; // For body parameters only, if there's a builder we'll @@ -1394,7 +1480,8 @@ impl Generator { .params .iter() .map(|param| match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; let details = ty.details(); let optional = @@ -1424,7 +1511,8 @@ impl Generator { .params .iter() .map(|param| match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; if ty.builder().is_some() { let type_name = ty.ident(); @@ -1437,6 +1525,7 @@ impl Generator { Ok(quote! {}) } } + OperationParameterType::RawBody => Ok(quote! {}), }) .collect::>>()?; @@ -1449,11 +1538,12 @@ impl Generator { .map(|param| { let param_name = format_ident!("{}", param.name); match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; let details = ty.details(); match (&details, ty.builder()) { - // TODO right now optional body paramters are not + // TODO right now optional body parameters are not // addressed (typify::TypeDetails::Option(_), Some(_)) => { unreachable!() @@ -1537,6 +1627,23 @@ impl Generator { } } + OperationParameterType::Form(_type_id) => { + let err_msg = format!( + "conversion to `reqwest::Body` for {} failed", + param.name, + ); + + Ok(quote! { + pub fn #param_name(mut self, value: B) -> Self + where B: std::convert::TryInto + { + self.#param_name = value.try_into() + .map_err(|_| #err_msg.to_string()); + self + } + }) + } + OperationParameterType::RawBody => { let err_msg = format!( "conversion to `reqwest::Body` for {} failed", @@ -2000,12 +2107,105 @@ impl Generator { )), } if enumeration.is_empty() => Ok(()), _ => Err(Error::UnexpectedFormat(format!( - "invalid schema for application/octet-stream: {:?}", + "invalid schema for {}: {:?}", + BodyContentType::OctetStream, schema ))), }?; OperationParameterType::RawBody } + BodyContentType::FormData => { + // For form data, we expect a key-value set of types, specific schema: + + // ```yaml + // type: "object" + // properties: + // file: + // description: "The file to upload" + // type: "string" + // format: "binary" + // ``` + // "schema": { + // "type": "string", + // "format": "binary" + // } + + let mapped = match schema.item(components)? { + openapiv3::Schema { + schema_data: + openapiv3::SchemaData { + nullable: false, + discriminator: None, + default: None, + // Other fields that describe or document the + // schema are fine. + .. + }, + schema_kind: + openapiv3::SchemaKind::Type(openapiv3::Type::Object( + openapiv3::ObjectType { + properties, + additional_properties, + .. + }, + )), + } => { + let mapped = Result::>::from_iter( + properties + .into_iter() + .map(|(name, property)| { + // properties must be plain key value types for now + let ReferenceOr::Item(property) = property else { + return Err(Error::UnexpectedFormat(format!( + "invalid schema for {}: didn't expect a reference", + BodyContentType::FormData, + )))}; + match &property.schema_kind { + openapiv3::SchemaKind::Type(openapiv3::Type::String( + openapiv3::StringType { + format: + openapiv3::VariantOrUnknownOrEmpty::Item( + openapiv3::StringFormat::Binary, + ), + pattern: None, + enumeration, + min_length: None, + max_length: None, + }, + )) if enumeration.is_empty() => { + Ok(name.to_owned()) + } + schema => { + Err(Error::UnexpectedFormat(format!( + "invalid schema for {}: {:?}", + BodyContentType::FormData, + schema + ))) + } + } + }))?; + Ok(mapped) + } + _ => Err(Error::UnexpectedFormat(format!( + "invalid schema for {}: {:?}", + BodyContentType::FormData, + schema + ))), + }?; + + let form_name = sanitize( + &format!( + "{}-form", + operation.operation_id.as_ref().unwrap(), + ), + Case::Pascal, + ); + let type_id = self + .type_space + .add_type_with_name(&schema.to_schema(), Some(form_name))?; + self.forms.insert(type_id.clone()); + OperationParameterType::Form(type_id) + } BodyContentType::Json | BodyContentType::FormUrlencoded => { // TODO it would be legal to have the encoding field set for // application/x-www-form-urlencoded content, but I'm not sure diff --git a/progenitor-impl/src/template.rs b/progenitor-impl/src/template.rs index 98db6085..27a12fca 100644 --- a/progenitor-impl/src/template.rs +++ b/progenitor-impl/src/template.rs @@ -51,6 +51,7 @@ impl PathTemplate { quote! { let url = format!(#fmt, #client.baseurl, #(#components,)*); + println!("Querying: {}", &url); } } diff --git a/progenitor-impl/src/util.rs b/progenitor-impl/src/util.rs index 4d4d2eba..d31b0114 100644 --- a/progenitor-impl/src/util.rs +++ b/progenitor-impl/src/util.rs @@ -4,7 +4,7 @@ use std::collections::BTreeMap; use indexmap::IndexMap; use openapiv3::{ - Components, Parameter, ReferenceOr, RequestBody, Response, Schema, + Components, Parameter, ReferenceOr, RequestBody, Response, Schema, SecurityScheme, }; use unicode_ident::{is_xid_continue, is_xid_start}; @@ -51,9 +51,9 @@ pub(crate) fn parameter_map<'a>( refs: &'a [ReferenceOr], components: &'a Option, ) -> Result> { - items(refs, components) + Result::from_iter(items(refs, components) .map(|res| res.map(|param| (¶m.parameter_data_ref().name, param))) - .collect() + ) } impl ComponentLookup for Parameter { @@ -88,6 +88,14 @@ impl ComponentLookup for Schema { } } +impl ComponentLookup for SecurityScheme { + fn get_components( + components: &Components, + ) -> &IndexMap> { + &components.security_schemes + } +} + pub(crate) enum Case { Pascal, Snake, diff --git a/progenitor/Cargo.toml b/progenitor/Cargo.toml index 07814993..6fbbbe98 100644 --- a/progenitor/Cargo.toml +++ b/progenitor/Cargo.toml @@ -32,6 +32,6 @@ futures = "0.3.27" percent-encoding = "2.2" rand = "0.8" regress = "0.5.0" -reqwest = { version = "0.11.16", features = ["json", "stream"] } +reqwest = { version = "0.11.16", features = ["json", "stream", "multipart"] } schemars = { version = "0.8.12", features = ["uuid1"] } uuid = { version = "1.3", features = ["serde", "v4"] } diff --git a/progenitor/src/main.rs b/progenitor/src/main.rs index 6e1397c5..020c1597 100644 --- a/progenitor/src/main.rs +++ b/progenitor/src/main.rs @@ -212,7 +212,7 @@ pub fn dependencies(builder: Generator, include_client: bool) -> Vec { let mut deps = vec![ format!("bytes = \"{}\"", dependency_versions.get("bytes").unwrap()), format!("futures-core = \"{}\"", dependency_versions.get("futures-core").unwrap()), - format!("reqwest = {{ version = \"{}\", default-features=false, features = [\"json\", \"stream\"] }}", dependency_versions.get("reqwest").unwrap()), + format!("reqwest = {{ version = \"{}\", default-features=false, features = [\"json\", \"stream\", \"multipart\"] }}", dependency_versions.get("reqwest").unwrap()), format!("serde = {{ version = \"{}\", features = [\"derive\"] }}", dependency_versions.get("serde").unwrap()), format!("serde_urlencoded = \"{}\"", dependency_versions.get("serde_urlencoded").unwrap()), ];