scuffle_cedar_policy_codegen/
module.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::str::FromStr;
3
4use cedar_policy_core::ast::Id;
5
6use crate::cedar_action::CedarAction;
7use crate::codegen::Codegen;
8use crate::error::{CodegenError, CodegenResult};
9use crate::types::{CedarRef, CedarType, CedarTypeStructField, NamespaceId};
10use crate::utils::{find_relative_path, to_snake_ident, to_upper_camel_ident};
11
12/// Represents a module with its path and items
13#[derive(Default)]
14pub(crate) struct Module {
15    root_path: Vec<syn::Ident>,
16    items: Vec<syn::Item>,
17    sub_modules: BTreeMap<syn::Ident, Module>,
18}
19
20impl Module {
21    /// Gets or creates a sub-module
22    pub(crate) fn sub_module(&mut self, name: impl AsRef<str>) -> &mut Self {
23        let ident = to_snake_ident(name);
24        self.sub_modules.entry(ident.clone()).or_insert_with(|| Module {
25            root_path: {
26                let mut path = self.root_path.clone();
27                path.push(ident);
28                path
29            },
30            items: Vec::new(),
31            sub_modules: BTreeMap::new(),
32        })
33    }
34
35    /// Converts this module into syntax items
36    pub(crate) fn into_items(self) -> Vec<syn::Item> {
37        let mut items = self.items;
38
39        for (ident, module) in self.sub_modules {
40            let mod_items = module.into_items();
41            if !mod_items.is_empty() {
42                items.push(syn::parse_quote! {
43                    pub mod #ident {
44                        #(#mod_items)*
45                    }
46                });
47            }
48        }
49
50        items
51    }
52
53    /// Handles sub-type generation
54    fn handle_sub_type(
55        &mut self,
56        codegen: &Codegen,
57        ns: &NamespaceId,
58        in_submodule: bool,
59        name: impl AsRef<str>,
60        ty: &CedarType,
61    ) -> CodegenResult<syn::Type> {
62        let name = name.as_ref();
63        match ty {
64            CedarType::Bool => Ok(syn::parse_quote!(bool)),
65            CedarType::Long => Ok(syn::parse_quote!(i64)),
66            CedarType::String => Ok(syn::parse_quote!(::std::string::String)),
67            CedarType::Set(element_type) => {
68                let sub = self.handle_sub_type(codegen, ns, in_submodule, name, element_type.as_ref())?;
69                Ok(syn::parse_quote!(::std::vec::Vec<#sub>))
70            }
71            CedarType::Enum(variants) => Ok(self.handle_enum_type(codegen, name, variants, ns, in_submodule)),
72            CedarType::Record {
73                fields,
74                allows_additional,
75            } => self.handle_record_type(codegen, ns, name, fields, *allows_additional, in_submodule),
76            CedarType::Entity {
77                parents,
78                shape,
79                tag_type,
80            } => self.handle_entity_type(codegen, ns, name, shape, tag_type.as_deref(), parents, in_submodule),
81            CedarType::Reference(r) => self.handle_reference_type(codegen, r, in_submodule),
82        }
83    }
84
85    /// Handles enum type generation
86    fn handle_enum_type(
87        &mut self,
88        codegen: &Codegen,
89        name: &str,
90        variants: &BTreeSet<String>,
91        ns: &NamespaceId,
92        in_submodule: bool,
93    ) -> syn::Type {
94        let variants_def = variants.iter().map(|item| {
95            let ident = to_upper_camel_ident(item);
96            quote::quote! {
97                #[serde(rename = #item)]
98                #ident,
99            }
100        });
101
102        let variants_match = variants.iter().map(|item| {
103            let ident = to_upper_camel_ident(item);
104            quote::quote! {
105                Self::#ident => #item,
106            }
107        });
108
109        let crate_path = &codegen.config().crate_path;
110
111        let type_name = to_upper_camel_ident(name);
112        let serde_path = format!("{}::macro_exports::serde", quote::quote!(#crate_path));
113        self.items.push(syn::parse_quote! {
114            #[derive(#crate_path::macro_exports::serde_derive::Serialize)]
115            #[serde(crate = #serde_path)]
116            pub enum #type_name {
117                #(#variants_def)*
118            }
119        });
120
121        let full_name = format!("{ns}::{name}").trim_start_matches("::").to_string();
122        self.items.push(syn::parse_quote! {
123            impl #crate_path::CedarEntity for #type_name {
124                type TagType = #crate_path::NoTag;
125                type Id = Self;
126                type Attrs = #crate_path::NoAttributes;
127
128                const TYPE_NAME: #crate_path::EntityTypeName = #crate_path::entity_type_name!(#full_name);
129
130                fn entity_type_name() -> &'static #crate_path::macro_exports::cedar_policy::EntityTypeName {
131                    static ENTITY_TYPE_NAME: ::std::sync::LazyLock<#crate_path::macro_exports::cedar_policy::EntityTypeName> = ::std::sync::LazyLock::new(|| {
132                        std::str::FromStr::from_str(#full_name).expect("failed to parse entity type name - bug in scuffle-cedar-policy-codegen")
133                    });
134
135                    &*ENTITY_TYPE_NAME
136                }
137            }
138        });
139
140        self.items.push(syn::parse_quote! {
141            impl #crate_path::CedarId for #type_name {
142                fn into_smol_string(self) -> #crate_path::macro_exports::smol_str::SmolStr {
143                    let raw = match self {
144                        #(#variants_match)*
145                    };
146
147                    #crate_path::macro_exports::smol_str::SmolStr::from(raw)
148                }
149            }
150        });
151
152        self.items.push(syn::parse_quote! {
153            impl #crate_path::CedarEnumEntity for #type_name {
154                fn into_entity(self) -> #crate_path::Entity<Self>
155                    where Self: Sized
156                {
157                    #crate_path::Entity::builder(self, #crate_path::NoAttributes).build()
158                }
159            }
160        });
161
162        if in_submodule {
163            self.root_path
164                .last()
165                .map(|s| syn::parse_quote!(#s :: #type_name))
166                .unwrap_or_else(|| syn::parse_quote!(#type_name))
167        } else {
168            syn::parse_quote!(#type_name)
169        }
170    }
171
172    /// Handles record type generation
173    fn handle_record_type(
174        &mut self,
175        codegen: &Codegen,
176        ns: &NamespaceId,
177        name: &str,
178        fields: &BTreeMap<String, CedarTypeStructField>,
179        allows_additional: bool,
180        in_submodule: bool,
181    ) -> CodegenResult<syn::Type> {
182        if allows_additional {
183            return Err(CodegenError::Unsupported("record types with additional attributes".into()));
184        }
185
186        let type_name = to_upper_camel_ident(name);
187        let field_definitions = fields
188            .iter()
189            .map(|(field_name, field)| {
190                let ident = to_snake_ident(field_name);
191                let sub_type = self
192                    .sub_module(field_name)
193                    .handle_sub_type(codegen, ns, true, field_name, &field.ty)?;
194
195                let mut serde_attrs = vec![quote::quote!(rename = #field_name)];
196                let final_type = if field.optional {
197                    serde_attrs.push(quote::quote!(skip_serializing_if = "::std::option::Option::is_none"));
198                    syn::parse_quote!(::std::option::Option<#sub_type>)
199                } else {
200                    sub_type
201                };
202
203                Ok(quote::quote! {
204                    #[serde(#(#serde_attrs),*)]
205                    pub #ident: #final_type,
206                })
207            })
208            .collect::<CodegenResult<Vec<_>>>()?;
209
210        let crate_path = &codegen.config().crate_path;
211        let serde_path = format!("{}::macro_exports::serde", quote::quote!(#crate_path));
212
213        self.items.push(syn::parse_quote! {
214            #[derive(#crate_path::macro_exports::serde_derive::Serialize)]
215            #[serde(crate = #serde_path)]
216            pub struct #type_name {
217                #(#field_definitions)*
218            }
219        });
220
221        Ok(if in_submodule {
222            self.root_path
223                .last()
224                .map(|s| syn::parse_quote!(#s :: #type_name))
225                .unwrap_or_else(|| syn::parse_quote!(#type_name))
226        } else {
227            syn::parse_quote!(#type_name)
228        })
229    }
230
231    /// Handles entity type generation
232    #[allow(clippy::too_many_arguments)]
233    fn handle_entity_type(
234        &mut self,
235        codegen: &Codegen,
236        ns: &NamespaceId,
237        name: &str,
238        shape: &CedarType,
239        tag_type: Option<&CedarType>,
240        parents: &[CedarRef],
241        in_submodule: bool,
242    ) -> CodegenResult<syn::Type> {
243        let path = self.handle_sub_type(codegen, ns, false, name, shape)?;
244        let crate_path = &codegen.config().crate_path;
245
246        let tag_type = tag_type
247            .as_ref()
248            .map(|tag_type| {
249                self.sub_module(name)
250                    .handle_sub_type(codegen, ns, true, "EntityTag", tag_type)
251            })
252            .unwrap_or_else(|| Ok(syn::parse_quote!(#crate_path::NoTag)))?;
253
254        let full_name = format!("{ns}::{name}").trim_start_matches("::").to_string();
255
256        self.items.push(syn::parse_quote! {
257            impl #crate_path::CedarEntity for #path {
258                type TagType = #tag_type;
259                type Id = ::std::string::String;
260                type Attrs = Self;
261
262                const TYPE_NAME: #crate_path::EntityTypeName = #crate_path::entity_type_name!(#full_name);
263
264                fn entity_type_name() -> &'static #crate_path::macro_exports::cedar_policy::EntityTypeName {
265                    static ENTITY_TYPE_NAME: ::std::sync::LazyLock<#crate_path::macro_exports::cedar_policy::EntityTypeName> = ::std::sync::LazyLock::new(|| {
266                        std::str::FromStr::from_str(#full_name).expect("failed to parse entity type name - bug in scuffle-cedar-policy-codegen")
267                    });
268
269                    &*ENTITY_TYPE_NAME
270                }
271            }
272        });
273
274        for parent in parents {
275            match codegen.resolve_ref(parent) {
276                None => return Err(CodegenError::UnresolvedReference(parent.to_string())),
277                Some(p) if !p.is_entity() => {
278                    return Err(CodegenError::ExpectedEntity {
279                        common_type: parent.to_string(),
280                        ty: format!("entity {name}"),
281                    });
282                }
283                Some(_) => {}
284            }
285
286            let parent_ty = find_relative_path(&self.root_path, &parent.ident_path());
287            self.items.push(syn::parse_quote! {
288                impl #crate_path::CedarChild<#parent_ty> for #path {}
289            });
290        }
291
292        Ok(if in_submodule { syn::parse_quote!(super::#path) } else { path })
293    }
294
295    /// Handles reference type generation
296    pub(crate) fn handle_reference_type(
297        &self,
298        codegen: &Codegen,
299        r: &CedarRef,
300        in_submodule: bool,
301    ) -> CodegenResult<syn::Type> {
302        let relative = if in_submodule {
303            &self.root_path[..self.root_path.len() - 1]
304        } else {
305            &self.root_path
306        };
307
308        let path = find_relative_path(relative, &r.ident_path());
309        let Some(reference) = codegen.resolve_ref(r) else {
310            return Err(CodegenError::UnresolvedReference(r.to_string()));
311        };
312
313        let crate_path = &codegen.config().crate_path;
314
315        if reference.is_entity() {
316            Ok(syn::parse_quote!(#crate_path::EntityUid<#path>))
317        } else {
318            Ok(syn::parse_quote!(#path))
319        }
320    }
321
322    /// Handles top-level type generation
323    pub(crate) fn handle_type(
324        &mut self,
325        codegen: &Codegen,
326        ns: &NamespaceId,
327        name: impl AsRef<str>,
328        ty: &CedarType,
329    ) -> CodegenResult<()> {
330        match ty {
331            CedarType::Bool | CedarType::Long | CedarType::String => {
332                let type_name = to_upper_camel_ident(name.as_ref());
333                let sub_type = self.handle_sub_type(codegen, ns, false, name, ty)?;
334                self.items.push(syn::parse_quote! {
335                    pub type #type_name = #sub_type;
336                });
337            }
338            CedarType::Set(_) => {
339                let type_name = to_upper_camel_ident(name.as_ref());
340                let sub_type = self.sub_module(name).handle_sub_type(codegen, ns, true, "SetInner", ty)?;
341                self.items.push(syn::parse_quote! {
342                    type #type_name = #sub_type;
343                });
344            }
345            CedarType::Reference(_) => {
346                let type_name = to_upper_camel_ident(name.as_ref());
347                let sub_type = self.handle_sub_type(codegen, ns, false, name, ty)?;
348                self.items.push(syn::parse_quote! {
349                    type #type_name = #sub_type;
350                });
351            }
352            ty => {
353                self.handle_sub_type(codegen, ns, false, name, ty)?;
354            }
355        }
356
357        Ok(())
358    }
359
360    pub(crate) fn handle_action(
361        &mut self,
362        codegen: &Codegen,
363        ns_id: &NamespaceId,
364        action: &str,
365        ty: &CedarAction,
366    ) -> CodegenResult<(), CodegenError> {
367        let ident = to_upper_camel_ident(action);
368
369        // Generate action struct
370        self.items.push(syn::parse_quote! {
371            pub struct #ident;
372        });
373
374        let crate_path = &codegen.config().crate_path;
375
376        let ty_name = CedarRef {
377            id: Id::from_str("Action").unwrap(),
378            namespace: ns_id.clone(),
379        }
380        .to_string();
381
382        // Generate Serialize implementation
383        self.items.push(syn::parse_quote! {
384            impl #crate_path::CedarActionEntity for #ident {
385                fn action_entity_uid() -> &'static #crate_path::macro_exports::cedar_policy::EntityUid {
386                    static ENTITY_UID: ::std::sync::LazyLock<#crate_path::macro_exports::cedar_policy::EntityUid> = ::std::sync::LazyLock::new(|| {
387                        #crate_path::macro_exports::cedar_policy::EntityUid::from_type_name_and_id(
388                            std::str::FromStr::from_str(#ty_name).expect("failed to parse euid - bug in scuffle-cedar-policy-codegen"),
389                            std::str::FromStr::from_str(#action).expect("failed to parse euid - bug in scuffle-cedar-policy-codegen"),
390                        )
391                    });
392
393                    &*ENTITY_UID
394                }
395            }
396        });
397
398        for parent in &ty.parents {
399            if !codegen.contains_action(ns_id, parent) {
400                return Err(CodegenError::UnresolvedReference(parent.to_string()));
401            }
402
403            let parent_ident = to_upper_camel_ident(&parent.name);
404            let parent_path = if let Some(pid) = &parent.id {
405                let idents = pid
406                    .namespace
407                    .items
408                    .iter()
409                    .chain(std::iter::once(&pid.id))
410                    .map(to_snake_ident)
411                    .chain(std::iter::once(parent_ident))
412                    .collect::<Vec<_>>();
413
414                find_relative_path(&self.root_path, &idents)
415            } else {
416                syn::parse_quote!(#parent_ident)
417            };
418
419            self.items.push(syn::parse_quote! {
420                impl #crate_path::CedarChild<#parent_path> for #ident {}
421            });
422        }
423
424        // Generate context type
425        let ctx = ty
426            .context
427            .as_ref()
428            .map(|ctx| self.sub_module(action).handle_sub_type(codegen, ns_id, true, "Context", ctx))
429            .unwrap_or_else(|| Ok(syn::parse_quote!(#crate_path::EmptyContext)))?;
430
431        // Generate action implementations
432        let resolve_types = |reference| match codegen.resolve_ref(reference) {
433            None => Err(CodegenError::UnresolvedReference(reference.to_string())),
434            Some(r) if !r.is_entity() => Err(CodegenError::ExpectedEntity {
435                common_type: reference.to_string(),
436                ty: format!("action {action}"),
437            }),
438            Some(_) => Ok(find_relative_path(&self.root_path, &reference.ident_path())),
439        };
440
441        let principals = ty.principals.iter().map(resolve_types).collect::<CodegenResult<Vec<_>>>()?;
442        let resources = ty.resources.iter().map(resolve_types).collect::<CodegenResult<Vec<_>>>()?;
443
444        for principal in principals {
445            for resource in &resources {
446                self.items.push(syn::parse_quote! {
447                    impl #crate_path::CedarAction<#principal, #resource> for #ident {
448                        type Context = #ctx;
449                    }
450                });
451            }
452        }
453
454        Ok(())
455    }
456}