tinc_build/
lib.rs

1//! The code generator for [`tinc`](https://crates.io/crates/tinc).
2#![cfg_attr(feature = "docs", doc = "## Feature flags")]
3#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
4//! ## Usage
5//!
6//! In your `build.rs`:
7//!
8//! ```rust,no_run
9//! # #[allow(clippy::needless_doctest_main)]
10//! fn main() {
11//!     tinc_build::Config::prost()
12//!         .compile_protos(&["proto/test.proto"], &["proto"])
13//!         .unwrap();
14//! }
15//! ```
16//!
17//! Look at [`Config`] to see different options to configure the generator.
18//!
19//! ## License
20//!
21//! This project is licensed under the MIT or Apache-2.0 license.
22//! You can choose between one of them if you use this work.
23//!
24//! `SPDX-License-Identifier: MIT OR Apache-2.0`
25#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
26#![cfg_attr(docsrs, feature(doc_auto_cfg))]
27#![deny(missing_docs)]
28#![deny(unsafe_code)]
29#![deny(unreachable_pub)]
30#![deny(clippy::mod_module_files)]
31#![cfg_attr(not(feature = "prost"), allow(unused_variables, dead_code))]
32
33use std::io::ErrorKind;
34use std::path::{Path, PathBuf};
35
36use anyhow::Context;
37use extern_paths::ExternPaths;
38
39use crate::path_set::PathSet;
40
41mod codegen;
42mod extern_paths;
43mod path_set;
44
45#[cfg(feature = "prost")]
46mod prost_explore;
47
48mod types;
49
/// Selects the code generator backend; `prost` is currently the only supported one.
#[derive(Clone, Copy, Debug)]
pub enum Mode {
    /// Generate the protobuf structures with `prost`.
    #[cfg(feature = "prost")]
    Prost,
}
57
58impl quote::ToTokens for Mode {
59    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
60        match self {
61            #[cfg(feature = "prost")]
62            Mode::Prost => quote::quote!(prost).to_tokens(tokens),
63            #[cfg(not(feature = "prost"))]
64            _ => unreachable!(),
65        }
66    }
67}
68
69#[derive(Default, Debug)]
70struct PathConfigs {
71    btree_maps: Vec<String>,
72    bytes: Vec<String>,
73    boxed: Vec<String>,
74    floats_with_non_finite_vals: PathSet,
75}
76
77/// A config for configuring how tinc builds / generates code.
78#[derive(Debug)]
79pub struct Config {
80    disable_tinc_include: bool,
81    root_module: bool,
82    mode: Mode,
83    paths: PathConfigs,
84    extern_paths: ExternPaths,
85    out_dir: PathBuf,
86}
87
88impl Config {
89    /// New config with prost mode.
90    #[cfg(feature = "prost")]
91    pub fn prost() -> Self {
92        Self::new(Mode::Prost)
93    }
94
95    /// Make a new config with a given mode.
96    pub fn new(mode: Mode) -> Self {
97        Self::new_with_out_dir(mode, std::env::var_os("OUT_DIR").expect("OUT_DIR not set"))
98    }
99
100    /// Make a new config with a given mode.
101    pub fn new_with_out_dir(mode: Mode, out_dir: impl Into<PathBuf>) -> Self {
102        Self {
103            disable_tinc_include: false,
104            mode,
105            paths: PathConfigs::default(),
106            extern_paths: ExternPaths::new(mode),
107            root_module: true,
108            out_dir: out_dir.into(),
109        }
110    }
111
112    /// Disable tinc auto-include. By default tinc will add its own
113    /// annotations into the include path of protoc.
114    pub fn disable_tinc_include(&mut self) -> &mut Self {
115        self.disable_tinc_include = true;
116        self
117    }
118
119    /// Disable the root module generation
120    /// which allows for `tinc::include_protos!()` without
121    /// providing a package.
122    pub fn disable_root_module(&mut self) -> &mut Self {
123        self.root_module = false;
124        self
125    }
126
127    /// Specify a path to generate a `BTreeMap` instead of a `HashMap` for proto map.
128    pub fn btree_map(&mut self, path: impl std::fmt::Display) -> &mut Self {
129        self.paths.btree_maps.push(path.to_string());
130        self
131    }
132
133    /// Specify a path to generate `bytes::Bytes` instead of `Vec<u8>` for proto bytes.
134    pub fn bytes(&mut self, path: impl std::fmt::Display) -> &mut Self {
135        self.paths.bytes.push(path.to_string());
136        self
137    }
138
139    /// Specify a path to wrap around a `Box` instead of including it directly into the struct.
140    pub fn boxed(&mut self, path: impl std::fmt::Display) -> &mut Self {
141        self.paths.boxed.push(path.to_string());
142        self
143    }
144
145    /// Specify a path to float/double field (or derivative, like repeated float/double)
146    /// that must use serializer/deserializer with non-finite values support (NaN/Infinity).
147    pub fn float_with_non_finite_vals(&mut self, path: impl std::fmt::Display) -> &mut Self {
148        self.paths.floats_with_non_finite_vals.insert(path);
149        self
150    }
151
152    /// Compile and generate all the protos with the includes.
153    pub fn compile_protos(&mut self, protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> anyhow::Result<()> {
154        match self.mode {
155            #[cfg(feature = "prost")]
156            Mode::Prost => self.compile_protos_prost(protos, includes),
157        }
158    }
159
160    /// Generate tinc code based on a precompiled FileDescriptorSet.
161    pub fn load_fds(&mut self, fds: impl bytes::Buf) -> anyhow::Result<()> {
162        match self.mode {
163            #[cfg(feature = "prost")]
164            Mode::Prost => self.load_fds_prost(fds),
165        }
166    }
167
168    #[cfg(feature = "prost")]
169    fn compile_protos_prost(&mut self, protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> anyhow::Result<()> {
170        let fd_path = self.out_dir.join("tinc.fd.bin");
171
172        let mut config = prost_build::Config::new();
173        config.file_descriptor_set_path(&fd_path);
174
175        let mut includes = includes.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
176
177        {
178            let tinc_out = self.out_dir.join("tinc");
179            std::fs::create_dir_all(&tinc_out).context("failed to create tinc directory")?;
180            std::fs::write(tinc_out.join("annotations.proto"), tinc_pb_prost::TINC_ANNOTATIONS)
181                .context("failed to write tinc_annotations.rs")?;
182            includes.push(&self.out_dir);
183        }
184
185        config.load_fds(protos, &includes).context("failed to generate tonic fds")?;
186        let fds_bytes = std::fs::read(fd_path).context("failed to read tonic fds")?;
187        self.load_fds_prost(fds_bytes.as_slice())
188    }
189
190    #[cfg(feature = "prost")]
191    fn load_fds_prost(&mut self, fds: impl bytes::Buf) -> anyhow::Result<()> {
192        use std::collections::BTreeMap;
193
194        use codegen::prost_sanatize::to_snake;
195        use codegen::utils::get_common_import_path;
196        use proc_macro2::Span;
197        use prost::Message;
198        use prost_reflect::DescriptorPool;
199        use prost_types::FileDescriptorSet;
200        use quote::{ToTokens, quote};
201        use syn::parse_quote;
202        use types::{ProtoPath, ProtoTypeRegistry};
203
204        let pool = DescriptorPool::decode(fds).context("failed to add tonic fds")?;
205
206        let mut registry = ProtoTypeRegistry::new(
207            self.mode,
208            self.extern_paths.clone(),
209            self.paths.floats_with_non_finite_vals.clone(),
210        );
211
212        let mut config = prost_build::Config::new();
213
214        // This option is provided to make sure prost_build does not internally
215        // set extern_paths. We manage that via a re-export of prost_types in the
216        // tinc crate.
217        config.compile_well_known_types();
218
219        config.btree_map(self.paths.btree_maps.iter());
220        self.paths.boxed.iter().for_each(|path| {
221            config.boxed(path);
222        });
223        config.bytes(self.paths.bytes.iter());
224
225        for (proto, rust) in self.extern_paths.paths() {
226            let proto = if proto.starts_with('.') {
227                proto.to_string()
228            } else {
229                format!(".{proto}")
230            };
231            config.extern_path(proto, rust.to_token_stream().to_string());
232        }
233
234        prost_explore::Extensions::new(&pool)
235            .process(&mut registry)
236            .context("failed to process extensions")?;
237
238        let mut packages = codegen::generate_modules(&registry)?;
239
240        packages.iter_mut().for_each(|(path, package)| {
241            if self.extern_paths.contains(path) {
242                return;
243            }
244
245            package.enum_configs().for_each(|(path, enum_config)| {
246                if self.extern_paths.contains(path) {
247                    return;
248                }
249
250                enum_config.attributes().for_each(|attribute| {
251                    config.enum_attribute(path, attribute.to_token_stream().to_string());
252                });
253                enum_config.variants().for_each(|variant| {
254                    let path = format!("{path}.{variant}");
255                    enum_config.variant_attributes(variant).for_each(|attribute| {
256                        config.field_attribute(&path, attribute.to_token_stream().to_string());
257                    });
258                });
259            });
260
261            package.message_configs().for_each(|(path, message_config)| {
262                if self.extern_paths.contains(path) {
263                    return;
264                }
265
266                message_config.attributes().for_each(|attribute| {
267                    config.message_attribute(path, attribute.to_token_stream().to_string());
268                });
269                message_config.fields().for_each(|field| {
270                    let path = format!("{path}.{field}");
271                    message_config.field_attributes(field).for_each(|attribute| {
272                        config.field_attribute(&path, attribute.to_token_stream().to_string());
273                    });
274                });
275                message_config.oneof_configs().for_each(|(field, oneof_config)| {
276                    let path = format!("{path}.{field}");
277                    oneof_config.attributes().for_each(|attribute| {
278                        // In prost oneofs (container) are treated as enums
279                        config.enum_attribute(&path, attribute.to_token_stream().to_string());
280                    });
281                    oneof_config.fields().for_each(|field| {
282                        let path = format!("{path}.{field}");
283                        oneof_config.field_attributes(field).for_each(|attribute| {
284                            config.field_attribute(&path, attribute.to_token_stream().to_string());
285                        });
286                    });
287                });
288            });
289
290            package.extra_items.extend(package.services.iter().flat_map(|service| {
291                let mut builder = tonic_build::CodeGenBuilder::new();
292
293                builder.emit_package(true).build_transport(true);
294
295                let make_service = |is_client: bool| {
296                    let mut builder = tonic_build::manual::Service::builder()
297                        .name(service.name())
298                        .package(&service.package);
299
300                    if !service.comments.is_empty() {
301                        builder = builder.comment(service.comments.to_string());
302                    }
303
304                    service
305                        .methods
306                        .iter()
307                        .fold(builder, |service_builder, (name, method)| {
308                            let codec_path =
309                                if let Some(Some(codec_path)) = (!is_client).then_some(method.codec_path.as_ref()) {
310                                    let path = get_common_import_path(&service.full_name, codec_path);
311                                    quote!(#path::<::tinc::reexports::tonic_prost::ProstCodec<_, _>>)
312                                } else {
313                                    quote!(::tinc::reexports::tonic_prost::ProstCodec)
314                                };
315
316                            let mut builder = tonic_build::manual::Method::builder()
317                                .input_type(
318                                    registry
319                                        .resolve_rust_path(&service.full_name, method.input.value_type().proto_path())
320                                        .unwrap()
321                                        .to_token_stream()
322                                        .to_string(),
323                                )
324                                .output_type(
325                                    registry
326                                        .resolve_rust_path(&service.full_name, method.output.value_type().proto_path())
327                                        .unwrap()
328                                        .to_token_stream()
329                                        .to_string(),
330                                )
331                                .codec_path(codec_path.to_string())
332                                .name(to_snake(name))
333                                .route_name(name);
334
335                            if method.input.is_stream() {
336                                builder = builder.client_streaming()
337                            }
338
339                            if method.output.is_stream() {
340                                builder = builder.server_streaming();
341                            }
342
343                            if !method.comments.is_empty() {
344                                builder = builder.comment(method.comments.to_string());
345                            }
346
347                            service_builder.method(builder.build())
348                        })
349                        .build()
350                };
351
352                let mut client: syn::ItemMod = syn::parse2(builder.generate_client(&make_service(true), "")).unwrap();
353                client.content.as_mut().unwrap().1.insert(
354                    0,
355                    parse_quote!(
356                        use ::tinc::reexports::tonic;
357                    ),
358                );
359
360                let mut server: syn::ItemMod = syn::parse2(builder.generate_server(&make_service(false), "")).unwrap();
361                server.content.as_mut().unwrap().1.insert(
362                    0,
363                    parse_quote!(
364                        use ::tinc::reexports::tonic;
365                    ),
366                );
367
368                [client.into(), server.into()]
369            }));
370        });
371
372        for package in packages.keys() {
373            match std::fs::remove_file(self.out_dir.join(format!("{package}.rs"))) {
374                Err(err) if err.kind() != ErrorKind::NotFound => return Err(anyhow::anyhow!(err).context("remove")),
375                _ => {}
376            }
377        }
378
379        let fds = FileDescriptorSet {
380            file: pool.file_descriptor_protos().cloned().collect(),
381        };
382
383        let fd_path = self.out_dir.join("tinc.fd.bin");
384        std::fs::write(fd_path, fds.encode_to_vec()).context("write fds")?;
385
386        config.compile_fds(fds).context("prost compile")?;
387
388        for (package, module) in &mut packages {
389            if self.extern_paths.contains(package) {
390                continue;
391            };
392
393            let path = self.out_dir.join(format!("{package}.rs"));
394            write_module(&path, std::mem::take(&mut module.extra_items)).with_context(|| package.to_owned())?;
395        }
396
397        #[derive(Default)]
398        struct Module<'a> {
399            proto_path: Option<&'a ProtoPath>,
400            children: BTreeMap<&'a str, Module<'a>>,
401        }
402
403        impl ToTokens for Module<'_> {
404            fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
405                let include = self
406                    .proto_path
407                    .map(|p| p.as_ref())
408                    .map(|path| quote!(include!(concat!(#path, ".rs"));));
409                let children = self.children.iter().map(|(part, child)| {
410                    let ident = syn::Ident::new(&to_snake(part), Span::call_site());
411                    quote! {
412                        #[allow(clippy::all)]
413                        pub mod #ident {
414                            #child
415                        }
416                    }
417                });
418                quote! {
419                    #include
420                    #(#children)*
421                }
422                .to_tokens(tokens);
423            }
424        }
425
426        if self.root_module {
427            let mut module = Module::default();
428            for package in packages.keys() {
429                let mut module = &mut module;
430                for part in package.split('.') {
431                    module = module.children.entry(part).or_default();
432                }
433                module.proto_path = Some(package);
434            }
435
436            let file: syn::File = parse_quote!(#module);
437            std::fs::write(self.out_dir.join("___root_module.rs"), prettyplease::unparse(&file))
438                .context("write root module")?;
439        }
440
441        Ok(())
442    }
443}
444
445fn write_module(path: &std::path::Path, module: Vec<syn::Item>) -> anyhow::Result<()> {
446    let mut file = match std::fs::read_to_string(path) {
447        Ok(content) if !content.is_empty() => syn::parse_file(&content).context("parse")?,
448        Err(err) if err.kind() != ErrorKind::NotFound => return Err(anyhow::anyhow!(err).context("read")),
449        _ => syn::File {
450            attrs: Vec::new(),
451            items: Vec::new(),
452            shebang: None,
453        },
454    };
455
456    file.items.extend(module);
457    std::fs::write(path, prettyplease::unparse(&file)).context("write")?;
458
459    Ok(())
460}