Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ impl<'b> CodeGenerator<'_, 'b> {
}

fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let type_ = field.descriptor.r#type();
let repeated = field.descriptor.label() == Label::Repeated;
let type_ = field.descriptor.type_or_default();
let repeated = field.descriptor.label_or_default() == Label::Repeated;
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let boxed = self
Expand Down Expand Up @@ -442,7 +442,7 @@ impl<'b> CodeGenerator<'_, 'b> {
.push_str(&format!(" = {:?}", bytes_type.annotation()));
}

match field.descriptor.label() {
match field.descriptor.label_or_default() {
Label::Optional => {
if optional {
self.buf.push_str(", optional");
Expand Down Expand Up @@ -946,7 +946,7 @@ impl<'b> CodeGenerator<'_, 'b> {
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
match field.r#type() {
match field.type_or_default() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Type::Uint32 | Type::Fixed32 => String::from("u32"),
Expand Down Expand Up @@ -1003,7 +1003,7 @@ impl<'b> CodeGenerator<'_, 'b> {
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
match field.type_or_default() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Type::Int32 => Cow::Borrowed("int32"),
Expand All @@ -1029,7 +1029,7 @@ impl<'b> CodeGenerator<'_, 'b> {
}

fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
match field.type_or_default() {
Type::Enum => Cow::Owned(format!(
"enumeration({})",
self.resolve_ident(field.type_name())
Expand All @@ -1043,11 +1043,11 @@ impl<'b> CodeGenerator<'_, 'b> {
return true;
}

if field.label() != Label::Optional {
if field.label_or_default() != Label::Optional {
return false;
}

match field.r#type() {
match field.type_or_default() {
Type::Message => true,
_ => self.syntax == Syntax::Proto2,
}
Expand All @@ -1074,7 +1074,7 @@ impl<'b> CodeGenerator<'_, 'b> {
/// Returns `true` if the repeated field type can be packed.
fn can_pack(field: &FieldDescriptorProto) -> bool {
matches!(
field.r#type(),
field.type_or_default(),
Type::Float
| Type::Double
| Type::Int32
Expand Down
16 changes: 8 additions & 8 deletions prost-build/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ impl<'a> Context<'a> {
oneof: Option<&str>,
field: &FieldDescriptorProto,
) -> bool {
if field.label() == Label::Repeated {
if field.label_or_default() == Label::Repeated {
// Repeated field are stored in Vec, therefore it is already heap allocated
return false;
}
let fd_type = field.r#type();
let fd_type = field.type_or_default();
if (fd_type == Type::Message || fd_type == Type::Group)
&& self
.message_graph
Expand Down Expand Up @@ -188,9 +188,9 @@ impl<'a> Context<'a> {
assert_eq!(".", &fq_message_name[..1]);

// repeated field cannot derive Copy
if field.label() == Label::Repeated {
if field.label_or_default() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
} else if field.type_or_default() == Type::Message {
// nested and boxed messages cannot derive Copy
if self
.message_graph
Expand All @@ -210,7 +210,7 @@ impl<'a> Context<'a> {
}
} else {
matches!(
field.r#type(),
field.type_or_default(),
Type::Float
| Type::Double
| Type::Int32
Expand Down Expand Up @@ -243,8 +243,8 @@ impl<'a> Context<'a> {
pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.r#type() == Type::Message {
if field.label() == Label::Repeated
if field.type_or_default() == Type::Message {
if field.label_or_default() == Label::Repeated
|| self
.message_graph
.is_nested(field.type_name(), fq_message_name)
Expand All @@ -255,7 +255,7 @@ impl<'a> Context<'a> {
}
} else {
matches!(
field.r#type(),
field.type_or_default(),
Type::Int32
| Type::Int64
| Type::Uint32
Expand Down
4 changes: 3 additions & 1 deletion prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ impl MessageGraph {
let msg_index = self.get_or_insert_index(msg_name.clone());

for field in &msg.field {
if field.r#type() == Type::Message && field.label() != Label::Repeated {
if field.type_or_default() == Type::Message
&& field.label_or_default() != Label::Repeated
{
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
}
Expand Down
25 changes: 22 additions & 3 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt;

use anyhow::{anyhow, bail, Error};
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens, TokenStreamExt};
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path};

use crate::field::{bool_attr, set_option, tag_attr, Label};
Expand Down Expand Up @@ -283,6 +283,16 @@ impl Field {
}
Err(_) => quote!(#ident),
};
let get_or_default = match syn::parse_str::<Index>(&ident_str) {
Ok(index) => {
let get = Ident::new(
&format!("get_{}_or_default", index.index),
Span::call_site(),
);
quote!(#get)
}
Err(_) => format_ident!("{ident}_or_default").to_token_stream(),
};

if let Ty::Enumeration(ref ty) = self.ty {
let set = Ident::new(&format!("set_{ident_str}"), Span::call_site());
Expand All @@ -307,16 +317,25 @@ impl Field {
}
Kind::Optional(ref default) => {
let get_doc = format!(
"Returns the enum value of `{ident_str}`, \
or `None` if the field is unset or set to an invalid enum value."
);
let get_or_default_doc = format!(
"Returns the enum value of `{ident_str}`, \
or the default if the field is unset or set to an invalid enum value."
);
quote! {
#[doc=#get_doc]
pub fn #get(&self) -> #ty {
pub fn #get(&self) -> ::core::option::Option<#ty> {
self.#ident.and_then(|x| {
let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
result.ok()
}).unwrap_or(#default)
})
}

#[doc=#get_or_default_doc]
pub fn #get_or_default(&self) -> #ty {
self.#get().unwrap_or(#default)
}

#[doc=#set_doc]
Expand Down
11 changes: 7 additions & 4 deletions tests/src/default_enum_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ include!(concat!(env!("OUT_DIR"), "/default_enum_value.rs"));
#[test]
fn test_default_enum() {
let msg = Test::default();
assert_eq!(msg.privacy_level_1(), PrivacyLevel::One);
assert_eq!(msg.privacy_level_3(), PrivacyLevel::PrivacyLevelThree);
assert_eq!(msg.privacy_level_1_or_default(), PrivacyLevel::One);
assert_eq!(
msg.privacy_level_4(),
msg.privacy_level_3_or_default(),
PrivacyLevel::PrivacyLevelThree
);
assert_eq!(
msg.privacy_level_4_or_default(),
PrivacyLevel::PrivacyLevelprivacyLevelFour
);

let msg = CMsgRemoteClientBroadcastHeader::default();
assert_eq!(
msg.msg_type(),
msg.msg_type_or_default(),
ERemoteClientBroadcastMsg::KERemoteClientBroadcastMsgDiscovery
);
}
Expand Down
Loading