diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3687a9e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,92 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +jobs: + test: + name: Test (unit + sqlite) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: cargo test --all + run: cargo test --all --all-features + + fmt: + name: Format check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - name: cargo fmt --check + run: cargo fmt --all -- --check + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + - name: cargo clippy + run: cargo clippy --all-targets --all-features -- -D warnings + + test-postgres: + name: Test (Postgres E2E) + runs-on: ubuntu-latest + services: + postgres: + image: postgres:16 + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: sqlx_gen_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres" + --health-interval 5s + --health-timeout 3s + --health-retries 10 + env: + PG_URL: postgres://postgres:postgres@localhost:5432/sqlx_gen_test + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: cargo test postgres e2e + run: cargo test --all-features --test e2e_postgres -- --include-ignored + continue-on-error: true + + test-mysql: + name: Test (MySQL E2E) + runs-on: ubuntu-latest + services: + mysql: + image: mysql:8.0 + env: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: sqlx_gen_test + ports: + - 3306:3306 + options: >- + --health-cmd "mysqladmin ping -uroot -proot" + --health-interval 5s + --health-timeout 3s + --health-retries 10 + env: + MYSQL_URL: 
mysql://root:root@localhost:3306/sqlx_gen_test + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: cargo test mysql e2e + run: cargo test --all-features --test e2e_mysql -- --include-ignored + continue-on-error: true diff --git a/.gitignore b/.gitignore index ccb5166..4e8875d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ /target -.vscode \ No newline at end of file +.vscode +.DS_Store +**/.DS_Store +/docs/superpowers/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 46b70d8..19bdacc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1499,7 +1499,7 @@ dependencies = [ [[package]] name = "sqlx-gen" -version = "0.5.5" +version = "0.5.6" dependencies = [ "clap", "env_logger", @@ -1519,7 +1519,7 @@ dependencies = [ [[package]] name = "sqlx-gen-macros" -version = "0.5.5" +version = "0.5.6" [[package]] name = "sqlx-macros" diff --git a/Cargo.toml b/Cargo.toml index 9d661dd..ecd374f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,43 @@ [workspace] members = ["crates/sqlx_gen", "crates/sqlx_gen_macros"] resolver = "2" + +# Shared package metadata. Each crate inherits via `*.workspace = true`, +# so version bumps happen here and propagate everywhere — sqlx-gen and +# sqlx-gen-macros can never drift apart. +[workspace.package] +version = "0.5.6" +edition = "2021" +rust-version = "1.75" +license = "MIT" +repository = "https://github.com/LeadcodeDev/sqlx-gen" +keywords = ["sqlx", "codegen", "postgres", "mysql", "sqlite"] +categories = ["database", "development-tools"] + +# Shared dependency definitions. The internal sqlx-gen-macros entry pins the +# version to whatever sqlx_gen itself ships, so the proc-macro and runtime +# crates are always released together. 
+[workspace.dependencies] +sqlx-gen-macros = { path = "crates/sqlx_gen_macros", version = "0.5.6" } +sqlx = { version = "0.8", features = [ + "runtime-tokio", + "tls-rustls-ring", + "postgres", + "mysql", + "sqlite", + "chrono", + "uuid", + "json", +] } +tokio = { version = "1", features = ["full"] } +clap = { version = "4", features = ["derive", "env"] } +heck = "0.5" +thiserror = "2" +quote = "1" +proc-macro2 = "1" +syn = "2" +prettyplease = "0.2" +log = "0.4" +env_logger = "0.11" +tempfile = "3" +pretty_assertions = "1" diff --git a/README.md b/README.md index 021a511..9b4a5e3 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,40 @@ All generated types include `#[sqlx_gen(...)]` annotations for tooling: | Domain type | `#[sqlx_gen(kind = "domain")]` | | Primary key field | `#[sqlx_gen(primary_key)]` | +## PostgreSQL — multi-schema setup + +When you introspect more than one schema (`-s auth,billing,public`), enums, +composite types, and domains carry an unqualified +`#[sqlx(type_name = "...")]` because `sqlx::postgres::PgTypeInfo::with_name` +does not accept `schema.type`. For PG to resolve those types at runtime, the +connection must include every non-default schema in its `search_path`. + +`sqlx-gen` prints the exact `SET search_path` snippet it needs after +introspection. Apply it on every new connection via an `after_connect` +hook: + +```rust +use sqlx::postgres::PgPoolOptions; + +let pool = PgPoolOptions::new() + .after_connect(|conn, _meta| Box::pin(async move { + sqlx::query("SET search_path TO public, auth, billing") + .execute(conn).await?; + Ok(()) + })) + .connect(&url).await?; +``` + +If two schemas declare a type with the same name (e.g. both `auth.role` and +`billing.role`), sqlx-gen prefixes the Rust identifier with the schema +PascalCase form (`AuthRole`, `BillingRole`) to keep the generated code +unambiguous. The bare PascalCase form (`Role`) is reserved for the default +schema and for unique names. 
+ +The `sqlx_gen::codegen::required_pg_search_path(&schema_info)` helper returns +the list of non-default schemas you need to include — handy when wiring this +into a build script. + ## License MIT diff --git a/crates/sqlx_gen/Cargo.toml b/crates/sqlx_gen/Cargo.toml index 5ccfe56..7769b84 100644 --- a/crates/sqlx_gen/Cargo.toml +++ b/crates/sqlx_gen/Cargo.toml @@ -1,13 +1,14 @@ [package] name = "sqlx-gen" -version = "0.5.5" -edition = "2021" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true description = "Generate Rust structs from database schema introspection" -license = "MIT" -repository = "https://github.com/LeadcodeDev/sqlx-gen" readme = "../../README.md" -keywords = ["sqlx", "codegen", "postgres", "mysql", "sqlite"] -categories = ["database", "development-tools"] [[bin]] name = "sqlx-gen" @@ -27,31 +28,24 @@ cli = [ "dep:prettyplease", "dep:log", "dep:env_logger", + "dep:tempfile", ] [dependencies] -sqlx-gen-macros = { path = "../sqlx_gen_macros", version = "0.5.4" } -sqlx = { version = "0.8", features = [ - "runtime-tokio", - "tls-rustls-ring", - "postgres", - "mysql", - "sqlite", - "chrono", - "uuid", - "json", -], optional = true } -tokio = { version = "1", features = ["full"], optional = true } -clap = { version = "4", features = ["derive", "env"], optional = true } -heck = { version = "0.5", optional = true } -thiserror = { version = "2", optional = true } -quote = { version = "1", optional = true } -proc-macro2 = { version = "1", optional = true } -syn = { version = "2", optional = true } -prettyplease = { version = "0.2", optional = true } -log = { version = "0.4", optional = true } -env_logger = { version = "0.11", optional = true } +sqlx-gen-macros.workspace = true +sqlx = { workspace = true, optional = true } +tokio = { workspace = true, optional = true } +clap = { workspace = true, optional = true } +heck = 
{ workspace = true, optional = true } +thiserror = { workspace = true, optional = true } +quote = { workspace = true, optional = true } +proc-macro2 = { workspace = true, optional = true } +syn = { workspace = true, optional = true } +prettyplease = { workspace = true, optional = true } +log = { workspace = true, optional = true } +env_logger = { workspace = true, optional = true } +tempfile = { workspace = true, optional = true } [dev-dependencies] -pretty_assertions = "1" -tempfile = "3" +pretty_assertions.workspace = true +tempfile.workspace = true diff --git a/crates/sqlx_gen/src/cli.rs b/crates/sqlx_gen/src/cli.rs index 18e322f..eb9364d 100644 --- a/crates/sqlx_gen/src/cli.rs +++ b/crates/sqlx_gen/src/cli.rs @@ -3,7 +3,10 @@ use std::collections::HashMap; use std::path::PathBuf; #[derive(Parser, Debug)] -#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")] +#[command( + name = "sqlx-gen", + about = "Generate Rust structs from database schema" +)] pub struct Cli { #[command(subcommand)] pub command: Command, @@ -91,6 +94,11 @@ pub struct EntitiesArgs { #[arg(long, default_value = "chrono")] pub time_crate: TimeCrate, + /// How to render PostgreSQL domains: `alias` (`pub type X = Y;`) or + /// `newtype` (`pub struct X(pub Y);` with `#[sqlx(transparent)]`). + #[arg(long, default_value = "alias")] + pub domain_style: DomainStyle, + /// Print to stdout without writing files #[arg(short = 'n', long)] pub dry_run: bool, @@ -106,6 +114,41 @@ impl EntitiesArgs { }) .collect() } + + /// Parse and validate `--type-overrides`. Each value must be a syntactically + /// valid Rust type (parseable by `syn::parse_str::`). Prevents + /// injection of arbitrary Rust into generated code. + pub fn parse_type_overrides_checked(&self) -> crate::error::Result> { + let mut map = HashMap::new(); + for s in &self.type_overrides { + let (k, v) = s.split_once('=').ok_or_else(|| { + crate::error::Error::Config(format!( + "Invalid --type-overrides entry '{}'. 
Expected format: sql_type=RustType", + s + )) + })?; + if k.is_empty() { + return Err(crate::error::Error::Config(format!( + "Empty SQL type key in --type-overrides entry '{}'", + s + ))); + } + if v.trim().is_empty() { + return Err(crate::error::Error::Config(format!( + "Empty Rust type value in --type-overrides entry '{}'", + s + ))); + } + syn::parse_str::(v).map_err(|e| { + crate::error::Error::Config(format!( + "Invalid Rust type in --type-overrides value '{}': {}", + v, e + )) + })?; + map.insert(k.to_string(), v.to_string()); + } + Ok(map) + } } #[derive(Parser, Debug)] @@ -131,7 +174,6 @@ pub struct CrudArgs { #[arg(short = 'm', long, value_delimiter = ',')] pub methods: Vec, - /// Use sqlx::query_as!() compile-time checked macros instead of query_as::<_, T>() functions #[arg(short = 'q', long)] pub query_macro: bool, @@ -199,6 +241,40 @@ pub enum DatabaseKind { Sqlite, } +/// How a Postgres domain should be rendered in Rust. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DomainStyle { + /// `pub type Email = String;` — transparent alias, zero overhead. + #[default] + Alias, + /// `pub struct Email(pub String);` with `#[sqlx(transparent)]` — preserves + /// type identity so user code can attach `impl` blocks / validation. + Newtype, +} + +impl std::str::FromStr for DomainStyle { + type Err = String; + fn from_str(s: &str) -> Result { + match s { + "alias" => Ok(Self::Alias), + "newtype" => Ok(Self::Newtype), + other => Err(format!( + "Unknown domain style '{}'. 
Expected: alias, newtype", + other + )), + } + } +} + +impl std::fmt::Display for DomainStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Alias => write!(f, "alias"), + Self::Newtype => write!(f, "newtype"), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum TimeCrate { #[default] @@ -268,7 +344,16 @@ pub struct Methods { pub delete: bool, } -const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "insert_many", "update", "overwrite", "delete"]; +const ALL_METHODS: &[&str] = &[ + "get_all", + "paginate", + "get", + "insert", + "insert_many", + "update", + "overwrite", + "delete", +]; impl Methods { /// Parse a list of method names. `"*"` enables all methods. @@ -333,6 +418,7 @@ mod tests { exclude_tables: None, views: false, time_crate: TimeCrate::Chrono, + domain_style: DomainStyle::Alias, dry_run: false, } } @@ -419,6 +505,62 @@ mod tests { assert!(args.parse_type_overrides().is_empty()); } + // ========== parse_type_overrides_checked ========== + + #[test] + fn test_overrides_checked_empty_ok() { + let args = make_entities_args_with_overrides(vec![]); + assert!(args.parse_type_overrides_checked().unwrap().is_empty()); + } + + #[test] + fn test_overrides_checked_simple_type() { + let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]); + let map = args.parse_type_overrides_checked().unwrap(); + assert_eq!(map.get("jsonb").unwrap(), "MyJson"); + } + + #[test] + fn test_overrides_checked_path_type() { + let args = make_entities_args_with_overrides(vec!["jsonb=crate::types::MyJson"]); + let map = args.parse_type_overrides_checked().unwrap(); + assert_eq!(map.get("jsonb").unwrap(), "crate::types::MyJson"); + } + + #[test] + fn test_overrides_checked_generic_type() { + let args = make_entities_args_with_overrides(vec!["bytea=Vec"]); + assert!(args.parse_type_overrides_checked().is_ok()); + } + + #[test] + fn test_overrides_checked_rejects_injection() { + let args = 
make_entities_args_with_overrides(vec!["jsonb=Vec; fn pwned() {}"]); + let result = args.parse_type_overrides_checked(); + assert!( + result.is_err(), + "must reject value that isn't a single Rust type" + ); + } + + #[test] + fn test_overrides_checked_rejects_no_equals() { + let args = make_entities_args_with_overrides(vec!["noequals"]); + assert!(args.parse_type_overrides_checked().is_err()); + } + + #[test] + fn test_overrides_checked_rejects_empty_value() { + let args = make_entities_args_with_overrides(vec!["jsonb="]); + assert!(args.parse_type_overrides_checked().is_err()); + } + + #[test] + fn test_overrides_checked_rejects_empty_key() { + let args = make_entities_args_with_overrides(vec!["=Foo"]); + assert!(args.parse_type_overrides_checked().is_err()); + } + #[test] fn test_overrides_single() { let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]); @@ -481,9 +623,16 @@ mod tests { #[test] fn test_exclude_tables_set() { let mut args = make_entities_args_with_overrides(vec![]); - args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]); + args.exclude_tables = Some(vec![ + "_migrations".to_string(), + "schema_versions".to_string(), + ]); assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2); - assert!(args.exclude_tables.as_ref().unwrap().contains(&"_migrations".to_string())); + assert!(args + .exclude_tables + .as_ref() + .unwrap() + .contains(&"_migrations".to_string())); } // ========== methods ========== @@ -583,7 +732,10 @@ mod tests { #[test] fn test_module_path_nested() { let p = PathBuf::from("src/db/entities/agent.rs"); - assert_eq!(module_path_from_file(&p).unwrap(), "crate::db::entities::agent"); + assert_eq!( + module_path_from_file(&p).unwrap(), + "crate::db::entities::agent" + ); } #[test] diff --git a/crates/sqlx_gen/src/codegen/composite_gen.rs b/crates/sqlx_gen/src/codegen/composite_gen.rs index 0b1342f..ff4def8 100644 --- a/crates/sqlx_gen/src/codegen/composite_gen.rs +++ 
b/crates/sqlx_gen/src/codegen/composite_gen.rs @@ -1,11 +1,11 @@ use std::collections::{BTreeSet, HashMap}; -use heck::{ToSnakeCase, ToUpperCamelCase}; +use heck::ToSnakeCase; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use crate::cli::{DatabaseKind, TimeCrate}; -use crate::codegen::{imports_for_derives, is_rust_keyword}; +use crate::codegen::{imports_for_derives, is_rust_keyword, rust_type_name_for}; use crate::introspect::{CompositeTypeInfo, SchemaInfo}; use crate::typemap; @@ -21,7 +21,21 @@ pub fn generate_composite( for imp in imports_for_derives(extra_derives) { imports.insert(imp); } - let struct_name = format_ident!("{}", composite.name.to_upper_camel_case()); + let rust_name = rust_type_name_for(schema_info, &composite.schema_name, &composite.name); + let struct_name = format_ident!("{}", rust_name); + let search_path_doc = if db_kind == DatabaseKind::Postgres + && !crate::codegen::is_default_schema(&composite.schema_name) + { + Some(format!( + "Lives in PostgreSQL schema `{schema}`. The sqlx connection must \ + include `{schema}` in its search_path so PG resolves the \ + unqualified `type_name = \"{name}\"` to this composite.", + schema = composite.schema_name, + name = composite.name, + )) + } else { + None + }; let doc = format!( "Composite type: {}.{}", @@ -45,29 +59,25 @@ pub fn generate_composite( derive_tokens.push(quote! { #ident }); } - let pg_name = if composite.schema_name != "public" { - format!("{}.{}", composite.schema_name, composite.name) - } else { - composite.name.clone() - }; + // Always unqualified — sqlx 0.8's PgTypeInfo::with_name does not accept "schema.type" + // and emitting it triggers runtime decode errors. Non-public schemas require the + // connection's `search_path` to include the schema. + let pg_name = &composite.name; let type_attr = quote! 
{ #[sqlx(type_name = #pg_name)] }; let fields: Vec = composite .fields .iter() .map(|col| { - let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate); + let rust_type = + typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate); if let Some(imp) = &rust_type.needs_import { imports.insert(imp.clone()); } let field_name_snake = col.name.to_snake_case(); let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) { - let prefixed = format!( - "{}_{}", - composite.name.to_snake_case(), - field_name_snake - ); + let prefixed = format!("{}_{}", composite.name.to_snake_case(), field_name_snake); (prefixed, true) } else { let changed = field_name_snake != col.name; @@ -94,8 +104,18 @@ pub fn generate_composite( }) .collect(); + // `#[derive(sqlx::Type)]` with `#[sqlx(type_name = "x")]` auto-generates + // `impl PgHasArrayType` returning `_x`. Emitting a second impl triggers + // E0119 in the user's crate. + let _ = db_kind; + + let search_path_doc_tokens = match &search_path_doc { + Some(m) => quote! { #[doc = #m] }, + None => quote! {}, + }; let tokens = quote! 
{ #[doc = #doc] + #search_path_doc_tokens #[derive(#(#derive_tokens),*)] #[sqlx_gen(kind = "composite")] #type_attr @@ -130,14 +150,22 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: None, } } fn gen(composite: &CompositeTypeInfo) -> String { let schema = SchemaInfo::default(); - let (tokens, _) = generate_composite(composite, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), TimeCrate::Chrono); - parse_and_format(&tokens) + let (tokens, _) = generate_composite( + composite, + DatabaseKind::Postgres, + &schema, + &[], + &HashMap::new(), + TimeCrate::Chrono, + ); + parse_and_format(&tokens).unwrap() } fn gen_with( @@ -146,18 +174,28 @@ mod tests { overrides: &HashMap, ) -> (String, BTreeSet) { let schema = SchemaInfo::default(); - let (tokens, imports) = generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides, TimeCrate::Chrono); - (parse_and_format(&tokens), imports) + let (tokens, imports) = generate_composite( + composite, + DatabaseKind::Postgres, + &schema, + derives, + overrides, + TimeCrate::Chrono, + ); + (parse_and_format(&tokens).unwrap(), imports) } // --- basic structure --- #[test] fn test_simple_composite() { - let c = make_composite("address", vec![ - make_field("street", "text", false), - make_field("city", "text", false), - ]); + let c = make_composite( + "address", + vec![ + make_field("street", "text", false), + make_field("city", "text", false), + ], + ); let code = gen(&c); assert!(code.contains("pub street: String")); assert!(code.contains("pub city: String")); @@ -185,16 +223,42 @@ mod tests { } #[test] - fn test_non_public_schema_qualified_type_name() { + fn test_does_not_emit_manual_pg_has_array_type_impl() { + // Regression for E0119 — `#[derive(sqlx::Type)]` already provides this + // impl when `type_name` is set, so emitting our own conflicted. 
+ let c = make_composite("address", vec![make_field("street", "text", false)]); + let code = gen(&c); + assert!( + !code.contains("PgHasArrayType"), + "must not emit a manual PgHasArrayType impl, got:\n{}", + code + ); + } + + #[test] + fn test_non_public_schema_type_name_is_unqualified() { + // Regression: previously emitted "geo.point" which crashes sqlx 0.8 at runtime. let c = CompositeTypeInfo { schema_name: "geo".to_string(), name: "point".to_string(), fields: vec![make_field("x", "float8", false)], }; let schema = SchemaInfo::default(); - let (tokens, _) = generate_composite(&c, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), TimeCrate::Chrono); - let code = parse_and_format(&tokens); - assert!(code.contains("sqlx(type_name = \"geo.point\")")); + let (tokens, _) = generate_composite( + &c, + DatabaseKind::Postgres, + &schema, + &[], + &HashMap::new(), + TimeCrate::Chrono, + ); + let code = parse_and_format(&tokens).unwrap(); + assert!( + code.contains("sqlx(type_name = \"point\")"), + "type_name must be unqualified for sqlx 0.8, got:\n{}", + code + ); + assert!(!code.contains("\"geo.point\"")); } #[test] diff --git a/crates/sqlx_gen/src/codegen/crud_gen.rs b/crates/sqlx_gen/src/codegen/crud_gen.rs index 1623e53..536e36f 100644 --- a/crates/sqlx_gen/src/codegen/crud_gen.rs +++ b/crates/sqlx_gen/src/codegen/crud_gen.rs @@ -5,6 +5,7 @@ use quote::{format_ident, quote}; use crate::cli::{DatabaseKind, Methods, PoolVisibility}; use crate::codegen::entity_parser::{ParsedEntity, ParsedField}; +use crate::codegen::identifiers::{quote_ident, quote_qualified}; pub fn generate_crud_from_parsed( entity: &ParsedEntity, @@ -20,10 +21,15 @@ pub fn generate_crud_from_parsed( let repo_name = format!("{}Repository", entity.struct_name); let repo_ident = format_ident!("{}", repo_name); - let table_name = match &entity.schema_name { - Some(schema) => format!("{}.{}", schema, entity.table_name), - None => entity.table_name.clone(), - }; + // Skip schema qualification for 
the well-known default schema of each + // backend ("public" for Postgres, "main" for SQLite, "dbo" for SQL Server- + // adjacent flows). Avoids verbose "public"."users" everywhere when the + // user has no other schema in play. + let schema_for_sql = entity + .schema_name + .as_deref() + .filter(|s| !crate::codegen::is_default_schema(s)); + let table_name = quote_qualified(schema_for_sql, &entity.table_name, db_kind); // Pool type (used via full path sqlx::PgPool etc., no import needed) let pool_type = pool_type_tokens(db_kind); @@ -246,9 +252,20 @@ pub fn generate_crud_from_parsed( if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) { let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name); - // When all columns are PKs (e.g. junction tables), use pk_fields for insert + // Insert source fields: + // - Junction table (all-PK): use the PKs themselves so something gets inserted. + // - Composite PK (>1) with extra columns: include the PKs alongside non-PK + // columns. Composite PKs are typically NOT auto-generated, so omitting + // them would produce a NOT NULL violation. The MySQL branch below also + // relies on the bound PK values when LAST_INSERT_ID is not applicable. + // - Single PK + extras: keep the legacy behaviour (exclude the PK and + // assume it's SERIAL/AUTO_INCREMENT). 
let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() { pk_fields.clone() + } else if pk_fields.len() > 1 { + let mut combined: Vec<&ParsedField> = pk_fields.clone(); + combined.extend(non_pk_fields.iter().copied()); + combined } else { non_pk_fields.clone() }; @@ -268,9 +285,9 @@ pub fn generate_crud_from_parsed( }) .collect(); - let col_names: Vec<&str> = insert_source_fields + let col_names: Vec = insert_source_fields .iter() - .map(|f| f.column_name.as_str()) + .map(|f| quote_ident(&f.column_name, db_kind)) .collect(); let col_list = col_names.join(", "); @@ -355,9 +372,9 @@ pub fn generate_crud_from_parsed( non_pk_fields.clone() }; - let col_names: Vec<&str> = insert_source_fields + let col_names: Vec = insert_source_fields .iter() - .map(|f| f.column_name.as_str()) + .map(|f| quote_ident(&f.column_name, db_kind)) .collect(); let col_list = col_names.join(", "); let num_cols = insert_source_fields.len(); @@ -437,7 +454,7 @@ pub fn generate_crud_from_parsed( .enumerate() .map(|(i, f)| { let p = placeholder(db_kind, i + 1); - format!("{} = {}", f.column_name, p) + format!("{} = {}", quote_ident(&f.column_name, db_kind), p) }) .collect(); let set_clause = set_cols.join(",\n "); @@ -447,7 +464,7 @@ pub fn generate_crud_from_parsed( .enumerate() .map(|(i, f)| { let p = placeholder_with_cast(db_kind, i + 1, f); - format!("{} = {}", f.column_name, p) + format!("{} = {}", quote_ident(&f.column_name, db_kind), p) }) .collect(); let set_clause_cast = set_cols_cast.join(",\n "); @@ -467,6 +484,7 @@ pub fn generate_crud_from_parsed( format!("UPDATE {}\nSET\n {}\nWHERE {}", table_name, sc, wc) } }; + let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause)); let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast)); @@ -615,7 +633,8 @@ pub fn generate_crud_from_parsed( .enumerate() .map(|(i, f)| { let p = placeholder(db_kind, i + 1); - format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p) + let 
col = quote_ident(&f.column_name, db_kind); + format!("{col} = COALESCE({p}, {col})", col = col, p = p) }) .collect(); let set_clause = set_cols.join(",\n "); @@ -626,7 +645,8 @@ pub fn generate_crud_from_parsed( .enumerate() .map(|(i, f)| { let p = placeholder_with_cast(db_kind, i + 1, f); - format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p) + let col = quote_ident(&f.column_name, db_kind); + format!("{col} = COALESCE({p}, {col})", col = col, p = p) }) .collect(); let set_clause_cast = set_cols_cast.join(",\n "); @@ -851,14 +871,26 @@ fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream { } } -/// Wraps a SQL string as a raw string literal `r#"..."#` in the generated code. -/// Multi-line SQL gets a leading newline so each clause starts on its own line. +/// Wraps a SQL string as a raw string literal `r#"..."#` (or `r##"..."##` +/// when the body contains `"#`) in the generated code. Multi-line SQL gets +/// a leading newline so each clause starts on its own line. fn raw_sql_lit(s: &str) -> TokenStream { - if s.contains('\n') { - format!("r#\"\n{}\n\"#", s).parse().unwrap() + // Pick the smallest number of `#` characters whose fence isn't present in + // the body. Quoted SQL identifiers (`"users"`) followed by `#` are rare in + // practice but a malicious or quirky DB name could trip the default fence. 
+ let mut hashes = 1usize; + while s.contains(&format!("\"{}", "#".repeat(hashes))) { + hashes += 1; + } + let fence = "#".repeat(hashes); + let body = if s.contains('\n') { + format!("\n{}\n", s) } else { - format!("r#\"{}\"#", s).parse().unwrap() - } + s.to_string() + }; + format!("r{fence}\"{body}\"{fence}", fence = fence, body = body) + .parse() + .expect("raw_sql_lit must produce a valid Rust raw string literal") } fn placeholder(db_kind: DatabaseKind, index: usize) -> String { @@ -887,7 +919,7 @@ fn build_where_clause_parsed( .enumerate() .map(|(i, f)| { let p = placeholder(db_kind, start_index + i); - format!("{} = {}", f.column_name, p) + format!("{} = {}", quote_ident(&f.column_name, db_kind), p) }) .collect::>() .join(" AND ") @@ -918,7 +950,7 @@ fn build_where_clause_cast( .enumerate() .map(|(i, f)| { let p = placeholder_with_cast(db_kind, start_index + i, f); - format!("{} = {}", f.column_name, p) + format!("{} = {}", quote_ident(&f.column_name, db_kind), p) }) .collect::>() .join(" AND ") @@ -962,18 +994,42 @@ fn build_insert_method_parsed( "SELECT *\nFROM {}\nWHERE {}", table_name, pk_where )); - let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id"); - quote! { - pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> { - sqlx::query!(#sql_macro, #(#macro_args),*) - .execute(&self.pool) - .await?; - let id = sqlx::query_scalar!(#last_insert_id_sql) - .fetch_one(&self.pool) - .await?; - sqlx::query_as!(#entity_ident, #select_sql, id) - .fetch_one(&self.pool) - .await + if pk_fields.len() > 1 { + // Composite PK → the user supplied every PK column in + // params, so SELECT by them directly. LAST_INSERT_ID is + // meaningless for composite keys (only the first + // auto-increment column populates it, if any). + let pk_macro_args: Vec = pk_fields + .iter() + .map(|f| { + let name = format_ident!("{}", f.rust_name); + quote! { params.#name } + }) + .collect(); + quote! 
{ + pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> { + sqlx::query!(#sql_macro, #(#macro_args),*) + .execute(&self.pool) + .await?; + sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*) + .fetch_one(&self.pool) + .await + } + } + } else { + let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id"); + quote! { + pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> { + sqlx::query!(#sql_macro, #(#macro_args),*) + .execute(&self.pool) + .await?; + let id = sqlx::query_scalar!(#last_insert_id_sql) + .fetch_one(&self.pool) + .await?; + sqlx::query_as!(#entity_ident, #select_sql, id) + .fetch_one(&self.pool) + .await + } } } } @@ -996,20 +1052,42 @@ fn build_insert_method_parsed( "SELECT *\nFROM {}\nWHERE {}", table_name, pk_where )); - let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()"); - quote! { - pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> { - sqlx::query(#sql) - #(#binds)* - .execute(&self.pool) - .await?; - let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql) - .fetch_one(&self.pool) - .await?; - sqlx::query_as::<_, #entity_ident>(#select_sql) - .bind(id) - .fetch_one(&self.pool) - .await + if pk_fields.len() > 1 { + let pk_binds: Vec<TokenStream> = pk_fields + .iter() + .map(|f| { + let name = format_ident!("{}", f.rust_name); + quote! { .bind(&params.#name) } + }) + .collect(); + quote! { + pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> { + sqlx::query(#sql) + #(#binds)* + .execute(&self.pool) + .await?; + sqlx::query_as::<_, #entity_ident>(#select_sql) + #(#pk_binds)* + .fetch_one(&self.pool) + .await + } + } + } else { + let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()"); + quote! 
{ + pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> { + sqlx::query(#sql) + #(#binds)* + .execute(&self.pool) + .await?; + let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql) + .fetch_one(&self.pool) + .await?; + sqlx::query_as::<_, #entity_ident>(#select_sql) + .bind(id) + .fetch_one(&self.pool) + .await + } } } } @@ -1123,27 +1201,54 @@ fn build_insert_many_transactionally_method( "SELECT *\nFROM {}\nWHERE {}", table_name, pk_where )); - let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()"); - quote! { - let mut tx = self.pool.begin().await?; - let mut results = Vec::with_capacity(entries.len()); - for params in &entries { - sqlx::query(#single_insert_sql) - #(#single_binds)* - .execute(&mut *tx) - .await?; - let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql) - .fetch_one(&mut *tx) - .await?; - let row = sqlx::query_as::<_, #entity_ident>(#select_sql) - .bind(id) - .fetch_one(&mut *tx) - .await?; - results.push(row); + if pk_fields.len() > 1 { + let pk_binds: Vec<TokenStream> = pk_fields + .iter() + .map(|f| { + let name = format_ident!("{}", f.rust_name); + quote! { .bind(&params.#name) } + }) + .collect(); + quote! { + let mut tx = self.pool.begin().await?; + let mut results = Vec::with_capacity(entries.len()); + for params in &entries { + sqlx::query(#single_insert_sql) + #(#single_binds)* + .execute(&mut *tx) + .await?; + let row = sqlx::query_as::<_, #entity_ident>(#select_sql) + #(#pk_binds)* + .fetch_one(&mut *tx) + .await?; + results.push(row); + } + tx.commit().await?; + Ok(results) + } + } else { + let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()"); + quote! 
{ + let mut tx = self.pool.begin().await?; + let mut results = Vec::with_capacity(entries.len()); + for params in &entries { + sqlx::query(#single_insert_sql) + #(#single_binds)* + .execute(&mut *tx) + .await?; + let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql) + .fetch_one(&mut *tx) + .await?; + let row = sqlx::query_as::<_, #entity_ident>(#select_sql) + .bind(id) + .fetch_one(&mut *tx) + .await?; + results.push(row); + } + tx.commit().await?; + Ok(results) } - tx.commit().await?; - Ok(results) } } }; @@ -1153,6 +1258,12 @@ fn build_insert_many_transactionally_method( &self, entries: Vec<#insert_params_ident>, ) -> Result, sqlx::Error> { + // Short-circuit empty batches: avoids opening a transaction and + // sending a zero-row INSERT (or a "VALUES " with no tuples, which + // Postgres rejects as a syntax error). + if entries.is_empty() { + return Ok(Vec::new()); + } #body } } @@ -1273,7 +1384,7 @@ mod tests { false, PoolVisibility::Private, ); - parse_and_format(&tokens) + parse_and_format(&tokens).unwrap() } fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String { @@ -1286,7 +1397,7 @@ mod tests { true, PoolVisibility::Private, ); - parse_and_format(&tokens) + parse_and_format(&tokens).unwrap() } fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String { @@ -1298,7 +1409,7 @@ mod tests { false, PoolVisibility::Private, ); - parse_and_format(&tokens) + parse_and_format(&tokens).unwrap() } fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String { @@ -1311,7 +1422,7 @@ mod tests { false, PoolVisibility::Private, ); - parse_and_format_with_tab_spaces(&tokens, tab_spaces) + parse_and_format_with_tab_spaces(&tokens, tab_spaces).unwrap() } // --- basic structure --- @@ -1345,7 +1456,7 @@ mod tests { false, PoolVisibility::Pub, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); assert!( code.contains("pub pool: sqlx::PgPool") || 
code.contains("pub pool: sqlx :: PgPool") ); @@ -1362,7 +1473,7 @@ mod tests { false, PoolVisibility::PubCrate, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); assert!( code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool") @@ -2239,7 +2350,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); assert!(code.contains("query_as!")); assert!(!code.contains("query_as::<")); } @@ -2309,7 +2420,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); assert!(!code.contains(".bind(")); } @@ -2378,7 +2489,7 @@ mod tests { true, PoolVisibility::Private, ); - parse_and_format(&tokens) + parse_and_format(&tokens).unwrap() } #[test] @@ -2466,7 +2577,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); // SELECT queries should use runtime query_as, not macro assert!(code.contains("query_as::<")); assert!(!code.contains("query_as!(")); @@ -2483,7 +2594,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); // DELETE still uses query! 
macro assert!(code.contains("query!")); } @@ -2546,7 +2657,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); assert!(code.contains("as_slice()")); } @@ -2561,7 +2672,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); // Should have as_slice() for insert and update let count = code.matches("as_slice()").count(); assert!( @@ -2582,7 +2693,7 @@ mod tests { false, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); // Runtime mode uses .bind() so no as_slice needed assert!(!code.contains("as_slice()")); } @@ -2612,7 +2723,7 @@ mod tests { true, PoolVisibility::Private, ); - let code = parse_and_format(&tokens); + let code = parse_and_format(&tokens).unwrap(); assert!( code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", @@ -2625,8 +2736,8 @@ mod tests { fn junction_entity() -> ParsedEntity { ParsedEntity { struct_name: "AnalysisRecord".to_string(), - table_name: "analysis.analysis__record".to_string(), - schema_name: None, + table_name: "analysis__record".to_string(), + schema_name: Some("analysis".to_string()), is_view: false, fields: vec![ make_field("record_id", "record_id", "uuid::Uuid", false, true), @@ -2656,7 +2767,7 @@ mod tests { ); assert!( code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"), - "Expected INSERT INTO clause:\n{}", + "Expected quoted INSERT INTO clause:\n{}", code ); assert!( @@ -2726,6 +2837,64 @@ mod tests { ); } + // --- composite PK + non-PK columns (MySQL) --- + + fn composite_pk_with_extra() -> ParsedEntity { + ParsedEntity { + struct_name: "OrderItems".to_string(), + table_name: "order_items".to_string(), + schema_name: None, + is_view: false, + fields: vec![ + make_field("order_id", "order_id", "i32", false, true), + make_field("product_id", 
"product_id", "i32", false, true), + make_field("qty", "qty", "i32", false, false), + ], + imports: vec![], + } + } + + #[test] + fn test_mysql_composite_pk_insert_uses_select_not_last_insert_id() { + let code = gen(&composite_pk_with_extra(), DatabaseKind::Mysql); + assert!( + !code.contains("LAST_INSERT_ID"), + "composite PK insert must not use LAST_INSERT_ID(), got:\n{}", + code + ); + assert!( + code.contains("SELECT *"), + "must SELECT the row back after INSERT, got:\n{}", + code + ); + assert!( + code.contains("WHERE order_id = ? AND product_id = ?"), + "SELECT must use bound composite PK values, got:\n{}", + code + ); + } + + #[test] + fn test_mysql_composite_pk_includes_pks_in_insert_params() { + let code = gen(&composite_pk_with_extra(), DatabaseKind::Mysql); + assert!( + code.contains("pub order_id"), + "InsertParams must expose composite PK column order_id, got:\n{}", + code + ); + assert!(code.contains("pub product_id")); + assert!(code.contains("pub qty")); + } + + #[test] + fn test_mysql_single_pk_insert_still_uses_last_insert_id() { + let code = gen(&standard_entity(), DatabaseKind::Mysql); + assert!( + code.contains("LAST_INSERT_ID"), + "single-PK MySQL insert should still rely on LAST_INSERT_ID()" + ); + } + // --- insert_many_transactionally --- #[test] diff --git a/crates/sqlx_gen/src/codegen/domain_gen.rs b/crates/sqlx_gen/src/codegen/domain_gen.rs index 65c119b..bd76e36 100644 --- a/crates/sqlx_gen/src/codegen/domain_gen.rs +++ b/crates/sqlx_gen/src/codegen/domain_gen.rs @@ -1,10 +1,10 @@ use std::collections::{BTreeSet, HashMap}; -use heck::ToUpperCamelCase; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use crate::cli::{DatabaseKind, TimeCrate}; +use crate::cli::{DatabaseKind, DomainStyle, TimeCrate}; +use crate::codegen::rust_type_name_for; use crate::introspect::{DomainInfo, SchemaInfo}; use crate::typemap; @@ -14,9 +14,28 @@ pub fn generate_domain( schema_info: &SchemaInfo, type_overrides: &HashMap, time_crate: TimeCrate, 
+) -> (TokenStream, BTreeSet) { + generate_domain_with_style( + domain, + db_kind, + schema_info, + type_overrides, + time_crate, + DomainStyle::Alias, + ) +} + +pub fn generate_domain_with_style( + domain: &DomainInfo, + db_kind: DatabaseKind, + schema_info: &SchemaInfo, + type_overrides: &HashMap, + time_crate: TimeCrate, + style: DomainStyle, ) -> (TokenStream, BTreeSet) { let mut imports = BTreeSet::new(); - let alias_name = format_ident!("{}", domain.name.to_upper_camel_case()); + let rust_name = rust_type_name_for(schema_info, &domain.schema_name, &domain.name); + let alias_name = format_ident!("{}", rust_name); let doc = format!( "Domain: {}.{} (base: {})", @@ -28,6 +47,7 @@ pub fn generate_domain( name: String::new(), data_type: domain.base_type.clone(), udt_name: domain.base_type.clone(), + udt_schema: None, is_nullable: false, is_primary_key: false, ordinal_position: 0, @@ -35,7 +55,8 @@ pub fn generate_domain( column_default: None, }; - let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides, time_crate); + let rust_type = + typemap::map_column(&fake_col, db_kind, schema_info, type_overrides, time_crate); if let Some(imp) = &rust_type.needs_import { imports.insert(imp.clone()); } @@ -46,10 +67,22 @@ pub fn generate_domain( }); let domain_doc = "sqlx_gen:kind=domain"; - let tokens = quote! { - #[doc = #doc] - #[doc = #domain_doc] - pub type #alias_name = #type_tokens; + let tokens = match style { + DomainStyle::Alias => quote! { + #[doc = #doc] + #[doc = #domain_doc] + pub type #alias_name = #type_tokens; + }, + DomainStyle::Newtype => { + imports.insert("use serde::{Serialize, Deserialize};".to_string()); + quote! 
{ + #[doc = #doc] + #[doc = #domain_doc] + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)] + #[sqlx(transparent)] + pub struct #alias_name(pub #type_tokens); + } + } }; (tokens, imports) @@ -70,14 +103,29 @@ mod tests { fn gen(domain: &DomainInfo) -> (String, BTreeSet) { let schema = SchemaInfo::default(); - let (tokens, imports) = generate_domain(domain, DatabaseKind::Postgres, &schema, &HashMap::new(), TimeCrate::Chrono); - (parse_and_format(&tokens), imports) + let (tokens, imports) = generate_domain( + domain, + DatabaseKind::Postgres, + &schema, + &HashMap::new(), + TimeCrate::Chrono, + ); + (parse_and_format(&tokens).unwrap(), imports) } - fn gen_with_overrides(domain: &DomainInfo, overrides: &HashMap) -> (String, BTreeSet) { + fn gen_with_overrides( + domain: &DomainInfo, + overrides: &HashMap, + ) -> (String, BTreeSet) { let schema = SchemaInfo::default(); - let (tokens, imports) = generate_domain(domain, DatabaseKind::Postgres, &schema, overrides, TimeCrate::Chrono); - (parse_and_format(&tokens), imports) + let (tokens, imports) = generate_domain( + domain, + DatabaseKind::Postgres, + &schema, + overrides, + TimeCrate::Chrono, + ); + (parse_and_format(&tokens).unwrap(), imports) } #[test] @@ -153,4 +201,61 @@ mod tests { let (_, imports) = gen(&d); assert!(imports.iter().any(|i| i.contains("chrono"))); } + + // ========== DomainStyle::Newtype ========== + + fn gen_newtype(domain: &DomainInfo) -> (String, BTreeSet) { + let schema = SchemaInfo::default(); + let (tokens, imports) = generate_domain_with_style( + domain, + DatabaseKind::Postgres, + &schema, + &HashMap::new(), + TimeCrate::Chrono, + DomainStyle::Newtype, + ); + (parse_and_format(&tokens).unwrap(), imports) + } + + #[test] + fn test_newtype_emits_tuple_struct() { + let d = make_domain("email", "text"); + let (code, _) = gen_newtype(&d); + assert!( + code.contains("pub struct Email(pub String)"), + "newtype must wrap the base type in a tuple struct, got:\n{}", + code + 
); + } + + #[test] + fn test_newtype_uses_transparent_derive() { + let d = make_domain("email", "text"); + let (code, _) = gen_newtype(&d); + assert!(code.contains("#[sqlx(transparent)]")); + assert!(code.contains("sqlx::Type")); + } + + #[test] + fn test_newtype_keeps_doc_comments() { + let d = make_domain("email", "text"); + let (code, _) = gen_newtype(&d); + assert!(code.contains("Domain: public.email (base: text)")); + assert!(code.contains("sqlx_gen:kind=domain")); + } + + #[test] + fn test_newtype_wraps_uuid_with_import() { + let d = make_domain("my_uuid", "uuid"); + let (code, imports) = gen_newtype(&d); + assert!(code.contains("pub struct MyUuid(pub Uuid)")); + assert!(imports.iter().any(|i| i.contains("uuid::Uuid"))); + } + + #[test] + fn test_newtype_does_not_emit_type_alias() { + let d = make_domain("email", "text"); + let (code, _) = gen_newtype(&d); + assert!(!code.contains("pub type Email")); + } } diff --git a/crates/sqlx_gen/src/codegen/entity_parser.rs b/crates/sqlx_gen/src/codegen/entity_parser.rs index fcc1f72..e911f27 100644 --- a/crates/sqlx_gen/src/codegen/entity_parser.rs +++ b/crates/sqlx_gen/src/codegen/entity_parser.rs @@ -45,9 +45,8 @@ pub struct ParsedEntity { /// Parse an entity struct from a `.rs` file on disk. pub fn parse_entity_file(path: &Path) -> crate::error::Result { let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?; - parse_entity_source(&source).map_err(|e| { - crate::error::Error::Config(format!("{}: {}", path.display(), e)) - }) + parse_entity_source(&source) + .map_err(|e| crate::error::Error::Config(format!("{}: {}", path.display(), e))) } /// Parse an entity struct from a Rust source string. @@ -127,13 +126,11 @@ fn extract_entity(item: &syn::ItemStruct) -> Result { let table_name = table_name.unwrap_or_else(|| struct_name.clone()); let fields = match &item.fields { - syn::Fields::Named(named) => { - named - .named - .iter() - .map(extract_field) - .collect::, _>>()? 
- } + syn::Fields::Named(named) => named + .named + .iter() + .map(extract_field) + .collect::, _>>()?, _ => return Err("Expected named fields".to_string()), }; @@ -149,7 +146,9 @@ fn extract_entity(item: &syn::ItemStruct) -> Result { /// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes. /// Returns (kind, schema_name, table_name). -fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option, Option, Option) { +fn parse_sqlx_gen_struct_attrs( + attrs: &[syn::Attribute], +) -> (Option, Option, Option) { let mut kind = None; let mut schema_name = None; let mut table_name = None; @@ -194,14 +193,11 @@ fn extract_attr_value(tokens: &str, key: &str) -> Option { /// Extract a ParsedField from a syn::Field. fn extract_field(field: &syn::Field) -> Result { - let rust_name = field - .ident - .as_ref() - .ok_or("Unnamed field")? - .to_string(); + let rust_name = field.ident.as_ref().ok_or("Unnamed field")?.to_string(); let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone()); - let (is_primary_key, sql_type, is_sql_array, column_default) = parse_sqlx_gen_field_attrs(&field.attrs); + let (is_primary_key, sql_type, is_sql_array, column_default) = + parse_sqlx_gen_field_attrs(&field.attrs); let rust_type = field.ty.to_token_stream().to_string(); let (is_nullable, inner_type) = extract_option_type(&field.ty); @@ -226,7 +222,9 @@ fn extract_field(field: &syn::Field) -> Result { /// Parse `#[sqlx_gen(...)]` attributes on a field. /// Returns (is_primary_key, sql_type, is_sql_array, column_default). 
-fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option, bool, Option) { +fn parse_sqlx_gen_field_attrs( + attrs: &[syn::Attribute], +) -> (bool, Option, bool, Option) { let mut is_pk = false; let mut sql_type = None; let mut is_array = false; @@ -674,7 +672,10 @@ mod tests { "#; let entity = parse_entity_source(source).unwrap(); let status = &entity.fields[1]; - assert_eq!(status.column_default, Some("'idle'::task_status".to_string())); + assert_eq!( + status.column_default, + Some("'idle'::task_status".to_string()) + ); } #[test] diff --git a/crates/sqlx_gen/src/codegen/enum_gen.rs b/crates/sqlx_gen/src/codegen/enum_gen.rs index 5ed27b2..60f5808 100644 --- a/crates/sqlx_gen/src/codegen/enum_gen.rs +++ b/crates/sqlx_gen/src/codegen/enum_gen.rs @@ -5,21 +5,76 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote}; use crate::cli::DatabaseKind; -use crate::codegen::imports_for_derives; -use crate::introspect::EnumInfo; +use crate::codegen::{imports_for_derives, rust_type_name_for}; +use crate::introspect::{EnumInfo, SchemaInfo}; + +/// Detect two SQL enum variants that collapse to the same Rust identifier after +/// `to_upper_camel_case` (e.g. `"foo bar"` and `"foo_bar"` both become `FooBar`). +/// Returns an error pointing at the offending pair so the user can rename one +/// side in the database before regenerating. +pub fn check_variant_collisions(enum_info: &EnumInfo) -> crate::error::Result<()> { + use std::collections::BTreeMap; + let mut seen: BTreeMap = BTreeMap::new(); + for v in &enum_info.variants { + let pascal = v.to_upper_camel_case(); + if let Some(prev) = seen.get(pascal.as_str()).copied() { + return Err(crate::error::Error::Config(format!( + "Enum '{}.{}': SQL variants '{}' and '{}' both map to Rust identifier '{}'. 
\ + Rename one of them in the database or use a custom mapping.", + enum_info.schema_name, enum_info.name, prev, v, pascal + ))); + } + seen.insert(pascal, v.as_str()); + } + Ok(()) +} pub fn generate_enum( enum_info: &EnumInfo, db_kind: DatabaseKind, extra_derives: &[String], +) -> (TokenStream, BTreeSet) { + // Backwards-compatible entry point — uses an empty SchemaInfo so the + // enum keeps its bare PascalCase name (no schema prefix). + generate_enum_with_schema(enum_info, db_kind, extra_derives, &SchemaInfo::default()) +} + +pub fn generate_enum_with_schema( + enum_info: &EnumInfo, + db_kind: DatabaseKind, + extra_derives: &[String], + schema_info: &SchemaInfo, ) -> (TokenStream, BTreeSet) { let mut imports = BTreeSet::new(); for imp in imports_for_derives(extra_derives) { imports.insert(imp); } - let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case()); + let rust_name = rust_type_name_for(schema_info, &enum_info.schema_name, &enum_info.name); + let enum_name = format_ident!("{}", rust_name); let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name); + // For non-default schemas, remind the user that sqlx 0.8 can only resolve + // unqualified type_name attributes — the connection must have the schema + // in its search_path. Emitted as a /// doc-comment so it shows up both in + // generated source and in rustdoc. + let search_path_doc = if db_kind == DatabaseKind::Postgres + && !crate::codegen::is_default_schema(&enum_info.schema_name) + { + let msg = format!( + "Lives in PostgreSQL schema `{schema}`. The sqlx connection \ + must include `{schema}` in its search_path so PG resolves the \ + unqualified `type_name = \"{name}\"` to this enum. 
Example:\n\ + \n\ + ```ignore\n\ + sqlx::query(\"SET search_path TO public, {schema}\")\n\ + ```", + schema = enum_info.schema_name, + name = enum_info.name, + ); + Some(msg) + } else { + None + }; imports.insert("use serde::{Serialize, Deserialize};".to_string()); imports.insert("use sqlx_gen::SqlxGen;".to_string()); @@ -38,13 +93,12 @@ pub fn generate_enum( derive_tokens.push(quote! { #ident }); } - // For PG, add #[sqlx(type_name = "...")] — schema-qualify for non-public schemas + // For PG, add #[sqlx(type_name = "...")] — always unqualified. + // sqlx 0.8's PgTypeInfo::with_name does NOT accept schema-qualified names; emitting + // "schema.type" causes runtime decode failures. The user is expected to set + // `search_path` on the connection so that PG resolves the unqualified type name. let type_attr = if db_kind == DatabaseKind::Postgres { - let pg_name = if enum_info.schema_name != "public" { - format!("{}.{}", enum_info.schema_name, enum_info.name) - } else { - enum_info.name.clone() - }; + let pg_name = &enum_info.name; quote! { #[sqlx(type_name = #pg_name)] } } else { quote! {} @@ -84,11 +138,22 @@ pub fn generate_enum( quote! {} }; + // Postgres arrays: `#[derive(sqlx::Type)]` with `#[sqlx(type_name = "x")]` + // already auto-generates `impl PgHasArrayType` returning `_x` in sqlx 0.8+. + // Emitting a second impl here triggers E0119 (conflicting implementations) + // in the user's crate. Leave the derive in charge. + let _ = db_kind; + let schema_name_str = &enum_info.schema_name; let enum_name_str = &enum_info.name; + let search_path_doc_tokens = match &search_path_doc { + Some(m) => quote! { #[doc = #m] }, + None => quote! {}, + }; let tokens = quote! 
{ #[doc = #doc] + #search_path_doc_tokens #[derive(#(#derive_tokens),*)] #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)] #type_attr @@ -118,7 +183,7 @@ mod tests { fn gen(info: &EnumInfo, db: DatabaseKind) -> String { let (tokens, _) = generate_enum(info, db, &[]); - parse_and_format(&tokens) + parse_and_format(&tokens).unwrap() } fn gen_with_derives( @@ -127,7 +192,7 @@ mod tests { derives: &[String], ) -> (String, BTreeSet) { let (tokens, imports) = generate_enum(info, db, derives); - (parse_and_format(&tokens), imports) + (parse_and_format(&tokens).unwrap(), imports) } // --- basic structure --- @@ -183,7 +248,60 @@ mod tests { } #[test] - fn test_postgres_non_public_schema_qualified_type_name() { + fn test_check_variant_collisions_detects_after_camel_case() { + let e = EnumInfo { + schema_name: "public".into(), + name: "weird".into(), + variants: vec!["foo bar".into(), "foo_bar".into()], + default_variant: None, + }; + let result = check_variant_collisions(&e); + assert!(result.is_err(), "must detect collision"); + let msg = result.unwrap_err().to_string(); + assert!( + msg.contains("FooBar"), + "error must mention conflicting Rust ident, got: {}", + msg + ); + assert!(msg.contains("foo bar") || msg.contains("foo_bar")); + } + + #[test] + fn test_check_variant_collisions_accepts_distinct_variants() { + let e = make_enum("status", vec!["active", "inactive"]); + assert!(check_variant_collisions(&e).is_ok()); + } + + #[test] + fn test_check_variant_collisions_accepts_single_variant() { + let e = make_enum("status", vec!["only"]); + assert!(check_variant_collisions(&e).is_ok()); + } + + #[test] + fn test_does_not_emit_manual_pg_has_array_type_impl() { + // Regression for E0119 — `#[derive(sqlx::Type)]` already provides this + // impl when `type_name` is set, so emitting our own conflicted. 
+ for db in [ + DatabaseKind::Postgres, + DatabaseKind::Mysql, + DatabaseKind::Sqlite, + ] { + let e = make_enum("status", vec!["a", "b"]); + let code = gen(&e, db); + assert!( + !code.contains("PgHasArrayType"), + "{:?}: must not emit a manual PgHasArrayType impl, got:\n{}", + db, + code + ); + } + } + + #[test] + fn test_postgres_non_public_schema_type_name_is_unqualified() { + // Regression: previously emitted "auth.role" which crashes sqlx 0.8 at runtime + // (PgTypeInfo::with_name does not accept schema-qualified names). let e = EnumInfo { schema_name: "auth".to_string(), name: "role".to_string(), @@ -191,8 +309,17 @@ mod tests { default_variant: None, }; let (tokens, _) = generate_enum(&e, DatabaseKind::Postgres, &[]); - let code = parse_and_format(&tokens); - assert!(code.contains("sqlx(type_name = \"auth.role\")")); + let code = parse_and_format(&tokens).unwrap(); + assert!( + code.contains("sqlx(type_name = \"role\")"), + "type_name must be unqualified for sqlx 0.8 compatibility, got:\n{}", + code + ); + assert!( + !code.contains("\"auth.role\""), + "type_name must NOT include schema; got:\n{}", + code + ); } #[test] @@ -204,6 +331,31 @@ mod tests { assert!(!code.contains("type_name = \"public.status\"")); } + #[test] + fn test_mysql_inline_enum_emits_rename_for_lowercase_variants() { + // Inline MySQL ENUM('active', 'inactive') → Rust variants are PascalCase + // and need #[sqlx(rename)] so encode/decode hits the SQL text values. + let e = make_enum("status", vec!["active", "inactive"]); + let code = gen(&e, DatabaseKind::Mysql); + assert!( + code.contains("sqlx(rename = \"active\")"), + "MySQL inline ENUM variant must carry rename for round-trip:\n{}", + code + ); + assert!(code.contains("sqlx(rename = \"inactive\")")); + // type_name does NOT exist for MySQL — only PG-native enums need it. 
+ assert!(!code.contains("type_name")); + } + + #[test] + fn test_mysql_inline_enum_preserves_case_sensitive_variants() { + let e = make_enum("priority", vec!["LOW", "HIGH"]); + let code = gen(&e, DatabaseKind::Mysql); + // PascalCase("LOW") = "Low" → rename required so SQL sees "LOW" + assert!(code.contains("sqlx(rename = \"LOW\")")); + assert!(code.contains("sqlx(rename = \"HIGH\")")); + } + #[test] fn test_mysql_no_type_name() { let e = make_enum("status", vec!["a"]); @@ -336,7 +488,11 @@ mod tests { let e = EnumInfo { schema_name: "public".to_string(), name: "task_status".to_string(), - variants: vec!["idle".to_string(), "running".to_string(), "done".to_string()], + variants: vec![ + "idle".to_string(), + "running".to_string(), + "done".to_string(), + ], default_variant: Some("idle".to_string()), }; let code = gen(&e, DatabaseKind::Postgres); @@ -377,14 +533,19 @@ mod tests { #[test] fn test_public_schema_full_output() { - let e = make_enum_in_schema("public", "order_status", vec!["pending", "shipped", "delivered"]); + let e = make_enum_in_schema( + "public", + "order_status", + vec!["pending", "shipped", "delivered"], + ); let code = gen(&e, DatabaseKind::Postgres); assert!(code.contains("Enum: public.order_status")); assert!(code.contains("pub enum OrderStatus")); assert!(code.contains("sqlx(type_name = \"order_status\")")); assert!(!code.contains("sqlx(type_name = \"public.order_status\")")); - assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")")); + assert!(code + .contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")")); assert!(code.contains("Pending")); assert!(code.contains("Shipped")); assert!(code.contains("Delivered")); @@ -392,14 +553,20 @@ mod tests { #[test] fn test_named_schema_full_output() { - let e = make_enum_in_schema("analysis", "toolcall_status", vec!["PENDING", "RUNNING", "DONE"]); + let e = make_enum_in_schema( + "analysis", + "toolcall_status", + vec!["PENDING", 
"RUNNING", "DONE"], + ); let code = gen(&e, DatabaseKind::Postgres); assert!(code.contains("Enum: analysis.toolcall_status")); assert!(code.contains("pub enum ToolcallStatus")); - assert!(code.contains("sqlx(type_name = \"analysis.toolcall_status\")")); - assert!(!code.contains("sqlx(type_name = \"toolcall_status\")")); - assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")")); + assert!(code.contains("sqlx(type_name = \"toolcall_status\")")); + assert!(!code.contains("\"analysis.toolcall_status\"")); + assert!(code.contains( + "sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")" + )); assert!(code.contains("Pending")); assert!(code.contains("Running")); assert!(code.contains("Done")); @@ -410,12 +577,17 @@ mod tests { let e = EnumInfo { schema_name: "billing".to_string(), name: "payment_status".to_string(), - variants: vec!["pending".to_string(), "paid".to_string(), "refunded".to_string()], + variants: vec![ + "pending".to_string(), + "paid".to_string(), + "refunded".to_string(), + ], default_variant: Some("pending".to_string()), }; let code = gen(&e, DatabaseKind::Postgres); - assert!(code.contains("sqlx(type_name = \"billing.payment_status\")")); + assert!(code.contains("sqlx(type_name = \"payment_status\")")); + assert!(!code.contains("\"billing.payment_status\"")); assert!(code.contains("impl Default for PaymentStatus")); assert!(code.contains("Self::Pending")); } @@ -425,7 +597,8 @@ mod tests { let e = make_enum_in_schema("audit", "log_level", vec!["info", "warn_high", "CRITICAL"]); let code = gen(&e, DatabaseKind::Postgres); - assert!(code.contains("sqlx(type_name = \"audit.log_level\")")); + assert!(code.contains("sqlx(type_name = \"log_level\")")); + assert!(!code.contains("\"audit.log_level\"")); assert!(code.contains("sqlx(rename = \"info\")")); assert!(code.contains("sqlx(rename = \"warn_high\")")); assert!(code.contains("WarnHigh")); diff --git 
a/crates/sqlx_gen/src/codegen/identifiers.rs b/crates/sqlx_gen/src/codegen/identifiers.rs new file mode 100644 index 0000000..837f54a --- /dev/null +++ b/crates/sqlx_gen/src/codegen/identifiers.rs @@ -0,0 +1,344 @@ +use crate::cli::DatabaseKind; + +/// SQL keywords reserved across at least one of Postgres / MySQL / SQLite. +/// Sorted so we can binary-search in [`is_reserved_keyword`]. Conservative +/// list — when in doubt, an identifier matching one of these gets quoted. +const SQL_RESERVED: &[&str] = &[ + "ABORT", + "ALL", + "ALTER", + "AND", + "ANY", + "AS", + "ASC", + "AUTHORIZATION", + "BEFORE", + "BEGIN", + "BETWEEN", + "BOTH", + "BY", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_SCHEMA", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DELETE", + "DESC", + "DISTINCT", + "DO", + "DROP", + "ELSE", + "END", + "EXCEPT", + "EXISTS", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GRANT", + "GROUP", + "HAVING", + "IF", + "IN", + "INDEX", + "INNER", + "INSERT", + "INTERSECT", + "INTO", + "IS", + "JOIN", + "KEY", + "LATERAL", + "LEADING", + "LEFT", + "LIKE", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NATURAL", + "NOT", + "NULL", + "OF", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "OUTER", + "OVERLAPS", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "RIGHT", + "ROLLBACK", + "SCHEMA", + "SELECT", + "SESSION_USER", + "SET", + "SIMILAR", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRIGGER", + "TRUE", + "UNION", + "UNIQUE", + "UPDATE", + "USER", + "USING", + "VALUES", + "VARIADIC", + "VIEW", + "WHEN", + "WHERE", + "WINDOW", + "WITH", +]; + +/// Case-insensitive reserved-word lookup. The list is sorted, so binary_search +/// keeps this O(log n). 
+fn is_reserved_keyword(name: &str) -> bool { + let upper = name.to_ascii_uppercase(); + SQL_RESERVED.binary_search(&upper.as_str()).is_ok() +} + +/// True when `name` is safe to emit unquoted in SQL for the given dialect. +/// +/// "Safe" means: starts with a letter or underscore (lowercase only — PG +/// folds unquoted idents to lowercase, so uppercase letters force a quote), +/// then ASCII letters / digits / underscores, and is not a reserved keyword. +pub fn is_safe_unquoted(name: &str, _db: DatabaseKind) -> bool { + if name.is_empty() { + return false; + } + let bytes = name.as_bytes(); + let first = bytes[0]; + let first_ok = first == b'_' || first.is_ascii_lowercase(); + if !first_ok { + return false; + } + for &b in &bytes[1..] { + let ok = b == b'_' || b.is_ascii_lowercase() || b.is_ascii_digit(); + if !ok { + return false; + } + } + !is_reserved_keyword(name) +} + +/// Quote a SQL identifier (table/column/schema) per database dialect, but +/// only when quoting is syntactically required. Trivially-safe identifiers +/// pass through untouched. +pub fn quote_ident(name: &str, db: DatabaseKind) -> String { + if is_safe_unquoted(name, db) { + name.to_string() + } else { + quote_ident_always(name, db) + } +} + +/// Always quote, regardless of safety. Doubles internal quote characters. +pub fn quote_ident_always(name: &str, db: DatabaseKind) -> String { + match db { + DatabaseKind::Mysql => format!("`{}`", name.replace('`', "``")), + DatabaseKind::Postgres | DatabaseKind::Sqlite => { + format!("\"{}\"", name.replace('"', "\"\"")) + } + } +} + +/// Quote a qualified table reference (`schema.table`) per dialect, or the +/// bare table when no schema is provided. Each part is conditionally quoted. 
+pub fn quote_qualified(schema: Option<&str>, table: &str, db: DatabaseKind) -> String {
+    match schema {
+        Some(s) => format!("{}.{}", quote_ident(s, db), quote_ident(table, db)),
+        None => quote_ident(table, db),
+    }
+}
+
+/// True if `name` is a plausible identifier: non-empty, made only of ASCII
+/// letters, digits and underscores, and not starting with a digit. Loosely the
+/// same predicate as [`is_safe_unquoted`] but it accepts uppercase letters and
+/// skips the reserved-word check — kept as a separate helper because it's a
+/// generic "could this be a safe identifier" question used by filename validation.
+pub fn is_safe_ident(name: &str) -> bool {
+    !name.is_empty()
+        && name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
+        && !name.starts_with(|c: char| c.is_ascii_digit())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // ---------- conditional quote_ident ----------
+
+    #[test]
+    fn safe_identifier_pg_not_quoted() {
+        assert_eq!(quote_ident("users", DatabaseKind::Postgres), "users");
+        assert_eq!(quote_ident("agent_id", DatabaseKind::Postgres), "agent_id");
+        assert_eq!(
+            quote_ident("agent__connector", DatabaseKind::Postgres),
+            "agent__connector"
+        );
+    }
+
+    #[test]
+    fn safe_identifier_mysql_not_quoted() {
+        assert_eq!(quote_ident("users", DatabaseKind::Mysql), "users");
+    }
+
+    #[test]
+    fn uppercase_identifier_pg_quoted() {
+        // PG folds unquoted identifiers to lowercase; uppercase needs quoting
+        // to preserve the case of the actual table name.
+        assert_eq!(quote_ident("Users", DatabaseKind::Postgres), "\"Users\"");
+    }
+
+    #[test]
+    fn reserved_word_quoted_in_pg() {
+        assert_eq!(quote_ident("select", DatabaseKind::Postgres), "\"select\"");
+        assert_eq!(quote_ident("user", DatabaseKind::Postgres), "\"user\"");
+        assert_eq!(quote_ident("order", DatabaseKind::Postgres), "\"order\"");
+    }
+
+    #[test]
+    fn reserved_word_quoted_in_mysql() {
+        assert_eq!(quote_ident("select", DatabaseKind::Mysql), "`select`");
+    }
+
+    #[test]
+    fn identifier_with_dash_quoted() {
+        assert_eq!(
+            quote_ident("user-id", DatabaseKind::Postgres),
+            "\"user-id\""
+        );
+    }
+
+    #[test]
+    fn identifier_starting_with_digit_quoted() {
+        assert_eq!(quote_ident("123abc", DatabaseKind::Postgres), "\"123abc\"");
+    }
+
+    #[test]
+    fn empty_identifier_quoted() {
+        // An empty name isn't valid SQL, but it fails the "safe" check and so
+        // is still emitted as an explicit `""` rather than as nothing at all.
+        assert_eq!(quote_ident("", DatabaseKind::Postgres), "\"\"");
+    }
+
+    #[test]
+    fn injection_attempt_quoted_and_escaped() {
+        assert_eq!(
+            quote_ident("user\"; DROP TABLE x; --", DatabaseKind::Postgres),
+            "\"user\"\"; DROP TABLE x; --\""
+        );
+    }
+
+    // ---------- quote_ident_always ----------
+
+    #[test]
+    fn always_quote_safe_identifier_pg() {
+        assert_eq!(
+            quote_ident_always("users", DatabaseKind::Postgres),
+            "\"users\""
+        );
+    }
+
+    #[test]
+    fn always_quote_safe_identifier_mysql() {
+        assert_eq!(quote_ident_always("users", DatabaseKind::Mysql), "`users`");
+    }
+
+    #[test]
+    fn always_quote_escapes_internal_backtick() {
+        assert_eq!(quote_ident_always("ev`il", DatabaseKind::Mysql), "`ev``il`");
+    }
+
+    // ---------- quote_qualified ----------
+
+    #[test]
+    fn qualified_safe_idents_not_quoted() {
+        assert_eq!(
+            quote_qualified(Some("agent"), "agent_connector", DatabaseKind::Postgres),
+            "agent.agent_connector"
+        );
+        assert_eq!(
+            quote_qualified(Some("app"), "users", DatabaseKind::Mysql),
+            "app.users"
+        );
+    }
+
+    #[test]
+    fn qualified_with_reserved_schema_quoted() {
+        assert_eq!(
+            quote_qualified(Some("user"), "items", DatabaseKind::Postgres),
+            "\"user\".items"
+        );
+    }
+
+    #[test]
+    fn qualified_without_schema() {
+        assert_eq!(quote_qualified(None, "users", DatabaseKind::Mysql), "users");
+    }
+
+    // ---------- is_safe_ident (filename helper) ----------
+
+    #[test]
+    fn safe_ident_rejects_dash() {
+        assert!(!is_safe_ident("user-id"));
+    }
+
+    #[test]
+    fn safe_ident_rejects_leading_digit() {
+        assert!(!is_safe_ident("123abc"));
+    }
+
+    #[test]
+    fn safe_ident_rejects_empty() {
+        assert!(!is_safe_ident(""));
+    }
+
+    #[test]
+    fn safe_ident_accepts_underscore_prefix() {
+        assert!(is_safe_ident("_private"));
+    }
+
+    #[test]
+    fn safe_ident_accepts_mixed_case() {
+        assert!(is_safe_ident("UserAccount2"));
+    }
+
+    // ---------- reserved list sanity ----------
+
+    #[test]
+    fn reserved_list_is_sorted() {
+        for pair in SQL_RESERVED.windows(2) {
+            assert!(
+                pair[0] < pair[1],
+                "SQL_RESERVED must be sorted; '{}' >= '{}'",
+                pair[0],
+                pair[1]
+            );
+        }
+    }
+}
diff --git a/crates/sqlx_gen/src/codegen/mod.rs b/crates/sqlx_gen/src/codegen/mod.rs
index 193a58e..8dcac73 100644
--- a/crates/sqlx_gen/src/codegen/mod.rs
+++ b/crates/sqlx_gen/src/codegen/mod.rs
@@ -3,6 +3,7 @@ pub mod crud_gen;
 pub mod domain_gen;
 pub mod entity_parser;
 pub mod enum_gen;
+pub mod identifiers;
 pub mod struct_gen;
 
 use std::collections::{BTreeSet, HashMap};
@@ -15,11 +16,11 @@ use crate::introspect::SchemaInfo;
 
 /// Rust reserved keywords that cannot be used as identifiers.
const RUST_KEYWORDS: &[&str] = &[ - "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", - "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", - "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", - "type", "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do", - "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", + "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do", "final", + "macro", "override", "priv", "try", "typeof", "unsized", "virtual", ]; /// Returns true if the given name is a Rust reserved keyword. @@ -71,6 +72,85 @@ pub fn is_default_schema(schema: &str) -> bool { DEFAULT_SCHEMAS.contains(&schema) } +/// Compute the Rust identifier for an enum / composite / domain. +/// +/// When the same SQL name is declared in more than one schema (e.g. both +/// `auth.role` and `billing.role` exist), the non-default-schema variants get +/// a `SchemaName` PascalCase prefix to avoid Rust-level identifier collisions. +/// Otherwise the bare PascalCase of the SQL name is used. +pub fn rust_type_name_for(schema_info: &SchemaInfo, schema: &str, name: &str) -> String { + use heck::ToUpperCamelCase; + if type_name_has_cross_schema_collision(schema_info, name) && !is_default_schema(schema) { + format!( + "{}{}", + schema.to_upper_camel_case(), + name.to_upper_camel_case() + ) + } else { + name.to_upper_camel_case() + } +} + +/// Compute the schemas that must appear in PostgreSQL's `search_path` for +/// the generated code to resolve every emitted unqualified `type_name`. 
+/// +/// Returns the deduplicated, sorted list of non-default schemas hosting +/// enums/composites/domains in `schema_info`. The caller can feed this into +/// the pool's connect-hook, e.g.: +/// +/// ```ignore +/// let schemas = sqlx_gen::codegen::required_pg_search_path(&info).join(", "); +/// sqlx::query(&format!("SET search_path TO public, {}", schemas)) +/// .execute(&pool).await?; +/// ``` +pub fn required_pg_search_path(schema_info: &SchemaInfo) -> Vec { + let mut schemas: std::collections::BTreeSet = std::collections::BTreeSet::new(); + for e in &schema_info.enums { + if !is_default_schema(&e.schema_name) { + schemas.insert(e.schema_name.clone()); + } + } + for c in &schema_info.composite_types { + if !is_default_schema(&c.schema_name) { + schemas.insert(c.schema_name.clone()); + } + } + for d in &schema_info.domains { + if !is_default_schema(&d.schema_name) { + schemas.insert(d.schema_name.clone()); + } + } + schemas.into_iter().collect() +} + +/// True when the SQL `name` is declared by enums / composites / domains living +/// in more than one schema. +pub fn type_name_has_cross_schema_collision(schema_info: &SchemaInfo, name: &str) -> bool { + let mut schemas: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new(); + schemas.extend( + schema_info + .enums + .iter() + .filter(|e| e.name == name) + .map(|e| e.schema_name.as_str()), + ); + schemas.extend( + schema_info + .composite_types + .iter() + .filter(|c| c.name == name) + .map(|c| c.schema_name.as_str()), + ); + schemas.extend( + schema_info + .domains + .iter() + .filter(|d| d.name == name) + .map(|d| d.schema_name.as_str()), + ); + schemas.len() > 1 +} + /// Build a module name, prefixing with schema only when the name collides /// (same table name exists in multiple schemas). 
pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String { @@ -85,10 +165,14 @@ pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: boo fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> { let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new(); for t in &schema_info.tables { - seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str()); + seen.entry(t.name.as_str()) + .or_default() + .insert(t.schema_name.as_str()); } for v in &schema_info.views { - seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str()); + seen.entry(v.name.as_str()) + .or_default() + .insert(v.schema_name.as_str()); } seen.into_iter() .filter(|(_, schemas)| schemas.len() > 1) @@ -113,7 +197,29 @@ pub fn generate( type_overrides: &HashMap, single_file: bool, time_crate: TimeCrate, -) -> Vec { +) -> crate::error::Result> { + generate_with_domain_style( + schema_info, + db_kind, + extra_derives, + type_overrides, + single_file, + time_crate, + crate::cli::DomainStyle::Alias, + ) +} + +/// Same as [`generate`] but lets the caller pick how Postgres domains are +/// rendered (alias vs newtype). 
+pub fn generate_with_domain_style( + schema_info: &SchemaInfo, + db_kind: DatabaseKind, + extra_derives: &[String], + type_overrides: &HashMap, + single_file: bool, + time_crate: TimeCrate, + domain_style: crate::cli::DomainStyle, +) -> crate::error::Result> { let mut files = Vec::new(); // Detect table/view names that appear in multiple schemas (collisions) @@ -121,11 +227,22 @@ pub fn generate( // Generate struct files for each table for table in &schema_info.tables { - let (tokens, imports) = - struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false, time_crate); + let (tokens, imports) = struct_gen::generate_struct( + table, + db_kind, + schema_info, + extra_derives, + type_overrides, + false, + time_crate, + ); let imports = filter_imports(&imports, single_file); - let code = format_tokens_with_imports(&tokens, &imports); - let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str())); + let code = format_tokens_with_imports(&tokens, &imports)?; + let module_name = build_module_name( + &table.schema_name, + &table.name, + colliding_names.contains(table.name.as_str()), + ); files.push(GeneratedFile { filename: format!("{}.rs", module_name), origin: None, @@ -135,11 +252,22 @@ pub fn generate( // Generate struct files for each view for view in &schema_info.views { - let (tokens, imports) = - struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true, time_crate); + let (tokens, imports) = struct_gen::generate_struct( + view, + db_kind, + schema_info, + extra_derives, + type_overrides, + true, + time_crate, + ); let imports = filter_imports(&imports, single_file); - let code = format_tokens_with_imports(&tokens, &imports); - let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str())); + let code = format_tokens_with_imports(&tokens, &imports)?; + let module_name = build_module_name( + 
&view.schema_name, + &view.name, + colliding_names.contains(view.name.as_str()), + ); files.push(GeneratedFile { filename: format!("{}.rs", module_name), origin: None, @@ -155,14 +283,16 @@ pub fn generate( // Enrich enums with default variants extracted from column defaults let enum_defaults = extract_enum_defaults(schema_info); for enum_info in &schema_info.enums { + enum_gen::check_variant_collisions(enum_info)?; let mut enriched = enum_info.clone(); if enriched.default_variant.is_none() { if let Some(default) = enum_defaults.get(&enum_info.name) { enriched.default_variant = Some(default.clone()); } } - let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives); - types_blocks.push(format_tokens(&tokens)); + let (tokens, imports) = + enum_gen::generate_enum_with_schema(&enriched, db_kind, extra_derives, schema_info); + types_blocks.push(format_tokens(&tokens)?); types_imports.extend(imports); } @@ -175,22 +305,25 @@ pub fn generate( type_overrides, time_crate, ); - types_blocks.push(format_tokens(&tokens)); + types_blocks.push(format_tokens(&tokens)?); types_imports.extend(imports); } for domain in &schema_info.domains { - let (tokens, imports) = - domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides, time_crate); - types_blocks.push(format_tokens(&tokens)); + let (tokens, imports) = domain_gen::generate_domain_with_style( + domain, + db_kind, + schema_info, + type_overrides, + time_crate, + domain_style, + ); + types_blocks.push(format_tokens(&tokens)?); types_imports.extend(imports); } if !types_blocks.is_empty() { - let import_lines: String = types_imports - .iter() - .map(|i| format!("{}\n", i)) - .collect(); + let import_lines: String = types_imports.iter().map(|i| format!("{}\n", i)).collect(); let body = types_blocks.join("\n"); let code = if import_lines.is_empty() { body @@ -204,7 +337,7 @@ pub fn generate( }); } - files + Ok(files) } /// Extract default variant values for enums by scanning column defaults 
across all tables and views. @@ -246,16 +379,12 @@ fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap { /// Handles formats like `'idle'::task_status`, `'idle'::public.task_status`. fn parse_pg_enum_default(default_expr: &str) -> Option { // Pattern: 'value'::some_type - let stripped = default_expr.trim(); - if stripped.starts_with('\'') { - if let Some(end_quote) = stripped[1..].find('\'') { - let value = &stripped[1..1 + end_quote]; - // Verify there's a :: cast after the closing quote - let rest = &stripped[2 + end_quote..]; - if rest.starts_with("::") { - return Some(value.to_string()); - } - } + let after_opening = default_expr.trim().strip_prefix('\'')?; + let end_quote = after_opening.find('\'')?; + let value = &after_opening[..end_quote]; + let rest = &after_opening[end_quote + 1..]; + if rest.starts_with("::") { + return Some(value.to_string()); } None } @@ -307,32 +436,45 @@ pub fn detect_tab_spaces(start_dir: &Path) -> usize { /// Parse and format a TokenStream via prettyplease, then post-process spacing. /// `tab_spaces` controls how many spaces per indentation level for SQL inside raw strings. -pub(crate) fn parse_and_format(tokens: &TokenStream) -> String { +pub(crate) fn parse_and_format(tokens: &TokenStream) -> crate::error::Result { parse_and_format_with_tab_spaces(tokens, 4) } -pub(crate) fn parse_and_format_with_tab_spaces(tokens: &TokenStream, tab_spaces: usize) -> String { - let file = syn::parse2::(tokens.clone()).unwrap_or_else(|e| { - log::error!("Failed to parse generated code: {}", e); - log::error!("This is a bug in sqlx-gen. Raw tokens:\n {}", tokens); - std::process::exit(1); - }); +pub(crate) fn parse_and_format_with_tab_spaces( + tokens: &TokenStream, + tab_spaces: usize, +) -> crate::error::Result { + let file = syn::parse2::(tokens.clone()).map_err(|e| { + crate::error::Error::Config(format!( + "Internal sqlx-gen bug: failed to parse generated code: {}. 
\ + Raw tokens:\n {}\n\ + Please report this with the input schema.", + e, tokens + )) + })?; let raw = prettyplease::unparse(&file); let raw = indent_multiline_raw_strings(&raw, tab_spaces); - add_blank_lines_between_items(&raw) + Ok(add_blank_lines_between_items(&raw)) } /// Format a single TokenStream block (no imports). -pub(crate) fn format_tokens(tokens: &TokenStream) -> String { +pub(crate) fn format_tokens(tokens: &TokenStream) -> crate::error::Result { parse_and_format(tokens) } -pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet) -> String { +pub fn format_tokens_with_imports( + tokens: &TokenStream, + imports: &BTreeSet, +) -> crate::error::Result { format_tokens_with_imports_and_tab_spaces(tokens, imports, 4) } -pub fn format_tokens_with_imports_and_tab_spaces(tokens: &TokenStream, imports: &BTreeSet, tab_spaces: usize) -> String { - let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces); +pub fn format_tokens_with_imports_and_tab_spaces( + tokens: &TokenStream, + imports: &BTreeSet, + tab_spaces: usize, +) -> crate::error::Result { + let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces)?; let used_imports: Vec<&String> = imports .iter() @@ -340,13 +482,10 @@ pub fn format_tokens_with_imports_and_tab_spaces(tokens: &TokenStream, imports: .collect(); if used_imports.is_empty() { - formatted + Ok(formatted) } else { - let import_lines: String = used_imports - .iter() - .map(|i| format!("{}\n", i)) - .collect(); - format!("{}\n\n{}", import_lines.trim_end(), formatted) + let import_lines: String = used_imports.iter().map(|i| format!("{}\n", i)).collect(); + Ok(format!("{}\n\n{}", import_lines.trim_end(), formatted)) } } @@ -383,10 +522,6 @@ fn is_import_used(import: &str, code: &str) -> bool { true } -/// Post-process formatted code to: -/// - Add blank lines between enum variants with `#[sqlx(rename` -/// - Add blank lines between top-level items (structs, impls) -/// - Add blank lines between logical 
blocks inside async methods /// Indent the content of multi-line raw string literals (`r#"..."#`) so SQL /// reads naturally in generated code. All SQL raw strings live inside `impl` /// methods, so content is indented at a fixed 2-level depth and relative @@ -397,7 +532,7 @@ fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String { // bake the right indentation at generation time. // The closing "# aligns with the r#" argument level (3 indent levels deep), // and SQL content gets one extra level beyond that. - let close_indent = 4 + tab_spaces; // impl(4) + fn_arg(tab) + let close_indent = 4 + tab_spaces; // impl(4) + fn_arg(tab) let sql_indent = 4 + 2 * tab_spaces; // impl(4) + fn_arg(tab) + sql(tab) let lines: Vec<&str> = code.lines().collect(); @@ -489,14 +624,14 @@ fn add_blank_lines_between_items(code: &str) -> String { let prev_is_await_end = prev.ends_with(".await?;") || prev.ends_with(".await?") || (prev.ends_with(';') && prev.contains(".unwrap_or(")); - if prev_is_await_end - && (trimmed.starts_with("let ") || trimmed.starts_with("Ok(")) - { + if prev_is_await_end && (trimmed.starts_with("let ") || trimmed.starts_with("Ok(")) { result.push(""); } // Separate a sqlx query `let` from preceding simple `let` assignments - if trimmed.starts_with("let ") && trimmed.contains("sqlx::") - && prev.starts_with("let ") && !prev.contains("sqlx::") + if trimmed.starts_with("let ") + && trimmed.contains("sqlx::") + && prev.starts_with("let ") + && !prev.contains("sqlx::") { result.push(""); } @@ -644,7 +779,10 @@ mod tests { #[test] fn test_build_collision_normalizes_double_underscore() { - assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector"); + assert_eq!( + build_module_name("billing", "agent__connector", true), + "billing_agent_connector" + ); } // ========== is_default_schema ========== @@ -751,6 +889,120 @@ mod tests { assert_eq!(result, input); } + // ========== rust_type_name_for / cross-schema collisions 
========== + + fn schema_with_two_role_enums() -> SchemaInfo { + SchemaInfo { + enums: vec![ + crate::introspect::EnumInfo { + schema_name: "auth".into(), + name: "role".into(), + variants: vec!["admin".into(), "user".into()], + default_variant: None, + }, + crate::introspect::EnumInfo { + schema_name: "billing".into(), + name: "role".into(), + variants: vec!["payer".into(), "payee".into()], + default_variant: None, + }, + ], + ..Default::default() + } + } + + #[test] + fn rust_type_name_prefixes_schema_on_cross_schema_collision() { + let s = schema_with_two_role_enums(); + assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole"); + assert_eq!(rust_type_name_for(&s, "billing", "role"), "BillingRole"); + } + + #[test] + fn rust_type_name_keeps_bare_name_when_unique() { + let s = SchemaInfo { + enums: vec![crate::introspect::EnumInfo { + schema_name: "auth".into(), + name: "role".into(), + variants: vec!["admin".into()], + default_variant: None, + }], + ..Default::default() + }; + assert_eq!(rust_type_name_for(&s, "auth", "role"), "Role"); + } + + #[test] + fn required_search_path_collects_non_default_schemas() { + let s = SchemaInfo { + enums: vec![ + crate::introspect::EnumInfo { + schema_name: "auth".into(), + name: "role".into(), + variants: vec!["x".into()], + default_variant: None, + }, + crate::introspect::EnumInfo { + schema_name: "public".into(), + name: "status".into(), + variants: vec!["y".into()], + default_variant: None, + }, + ], + composite_types: vec![crate::introspect::CompositeTypeInfo { + schema_name: "billing".into(), + name: "addr".into(), + fields: vec![], + }], + domains: vec![crate::introspect::DomainInfo { + schema_name: "auth".into(), + name: "email".into(), + base_type: "text".into(), + }], + ..Default::default() + }; + // Sorted, deduplicated, public excluded. 
+ assert_eq!(required_pg_search_path(&s), vec!["auth", "billing"]); + } + + #[test] + fn required_search_path_empty_when_only_default_schema() { + let s = SchemaInfo { + enums: vec![crate::introspect::EnumInfo { + schema_name: "public".into(), + name: "status".into(), + variants: vec!["y".into()], + default_variant: None, + }], + ..Default::default() + }; + assert!(required_pg_search_path(&s).is_empty()); + } + + #[test] + fn rust_type_name_default_schema_keeps_bare_name_even_on_collision() { + let s = SchemaInfo { + enums: vec![ + crate::introspect::EnumInfo { + schema_name: "public".into(), + name: "role".into(), + variants: vec!["a".into()], + default_variant: None, + }, + crate::introspect::EnumInfo { + schema_name: "auth".into(), + name: "role".into(), + variants: vec!["b".into()], + default_variant: None, + }, + ], + ..Default::default() + }; + // public stays "Role"; auth gets the schema prefix to break the tie. + assert_eq!(rust_type_name_for(&s, "public", "role"), "Role"); + assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole"); + } + // ========== filter_imports ========== #[test] @@ -806,6 +1058,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: None, } } @@ -813,7 +1066,15 @@ mod tests { #[test] fn test_generate_empty_schema() { let schema = SchemaInfo::default(); - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files.is_empty()); } @@ -823,7 +1084,15 @@ mod tests { tables: vec![make_table("users", vec![make_col("id", "int4")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, 
+ ) + .unwrap(); assert_eq!(files.len(), 1); assert_eq!(files[0].filename, "users.rs"); } @@ -837,7 +1106,15 @@ mod tests { ], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 2); } @@ -852,7 +1129,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 1); assert_eq!(files[0].filename, "types.rs"); } @@ -878,7 +1163,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); // Should produce exactly 1 types.rs let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect(); assert_eq!(types_files.len(), 1); @@ -896,7 +1189,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 2); // users.rs + types.rs } @@ -906,7 +1207,15 @@ mod tests { tables: vec![make_table("user__data", vec![make_col("id", "int4")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); 
assert_eq!(files[0].filename, "user_data.rs"); } @@ -916,7 +1225,15 @@ mod tests { tables: vec![make_table("users", vec![make_col("id", "int4")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files[0].origin, None); } @@ -931,7 +1248,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files[0].origin, None); } @@ -947,7 +1272,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + true, + TimeCrate::Chrono, + ) + .unwrap(); // struct file should not have super::types:: imports let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap(); assert!(!struct_file.code.contains("super::types::")); @@ -966,7 +1299,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap(); assert!(struct_file.code.contains("super::types::")); } @@ -978,7 +1319,15 @@ mod tests { ..Default::default() }; let derives = vec!["Serialize".to_string()]; - let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &derives, + 
&HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("Serialize")); } @@ -994,7 +1343,15 @@ mod tests { ..Default::default() }; let derives = vec!["Serialize".to_string()]; - let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &derives, + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("Serialize")); } @@ -1006,17 +1363,25 @@ mod tests { tables: vec![make_table("users", vec![make_col("data", "jsonb")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &overrides, + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("MyJson")); } #[test] fn test_generate_valid_rust_syntax() { let schema = SchemaInfo { - tables: vec![make_table("users", vec![ - make_col("id", "int4"), - make_col("name", "text"), - ])], + tables: vec![make_table( + "users", + vec![make_col("id", "int4"), make_col("name", "text")], + )], enums: vec![EnumInfo { schema_name: "public".to_string(), name: "status".to_string(), @@ -1025,11 +1390,24 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); for f in &files { // Should be parseable as valid Rust let parse_result = syn::parse_file(&f.code); - assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err()); + assert!( + parse_result.is_ok(), + "Failed to parse {}: {:?}", + f.filename, + parse_result.err() + ); } } @@ -1049,7 +1427,15 @@ mod tests { views: vec![make_view("active_users", vec![make_col("id", "int4")])], 
..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 1); assert_eq!(files[0].filename, "active_users.rs"); } @@ -1060,7 +1446,15 @@ mod tests { views: vec![make_view("active_users", vec![make_col("id", "int4")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files[0].origin, None); } @@ -1071,40 +1465,72 @@ mod tests { views: vec![make_view("active_users", vec![make_col("id", "int4")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 2); } #[test] fn test_generate_view_valid_rust() { let schema = SchemaInfo { - views: vec![make_view("active_users", vec![ - make_col("id", "int4"), - make_col("name", "text"), - ])], + views: vec![make_view( + "active_users", + vec![make_col("id", "int4"), make_col("name", "text")], + )], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); let parse_result = syn::parse_file(&files[0].code); - assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err()); + assert!( + parse_result.is_ok(), + "Failed to parse: {:?}", + parse_result.err() + ); } #[test] fn test_generate_view_nullable_column() { let schema = SchemaInfo { - views: 
vec![make_view("v", vec![ColumnInfo { - name: "email".to_string(), - data_type: "text".to_string(), - udt_name: "text".to_string(), - is_nullable: true, - is_primary_key: false, - ordinal_position: 0, - schema_name: "public".to_string(), - column_default: None, - }])], + views: vec![make_view( + "v", + vec![ColumnInfo { + name: "email".to_string(), + data_type: "text".to_string(), + udt_name: "text".to_string(), + is_nullable: true, + is_primary_key: false, + ordinal_position: 0, + schema_name: "public".to_string(), + udt_schema: None, + column_default: None, + }], + )], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("Option")); } @@ -1121,7 +1547,15 @@ mod tests { ], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect(); assert!(filenames.contains(&"users.rs")); assert!(filenames.contains(&"billing_users.rs")); @@ -1140,7 +1574,15 @@ mod tests { ], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect(); assert!(filenames.contains(&"users.rs")); assert!(filenames.contains(&"invoices.rs")); @@ -1155,7 +1597,15 @@ mod tests { ], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files 
= generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files[0].filename, "users.rs"); assert_eq!(files[1].filename, "posts.rs"); } @@ -1167,7 +1617,15 @@ mod tests { views: vec![make_view("active_users", vec![make_col("id", "int4")])], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + true, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 2); } @@ -1221,6 +1679,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: Some("'idle'::task_status".to_string()), }], }], @@ -1250,6 +1709,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: None, }], }], @@ -1279,6 +1739,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: Some("'hello'::character varying".to_string()), }], }], @@ -1303,6 +1764,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: Some("'idle'::task_status".to_string()), }], }], @@ -1314,7 +1776,15 @@ mod tests { }], ..Default::default() }; - let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = generate( + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap(); assert!(types_file.code.contains("impl Default for TaskStatus")); assert!(types_file.code.contains("Self::Idle")); diff --git a/crates/sqlx_gen/src/codegen/struct_gen.rs b/crates/sqlx_gen/src/codegen/struct_gen.rs index 93b87eb..9d98882 100644 --- 
a/crates/sqlx_gen/src/codegen/struct_gen.rs +++ b/crates/sqlx_gen/src/codegen/struct_gen.rs @@ -47,20 +47,17 @@ pub fn generate_struct( .columns .iter() .map(|col| { - let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate); + let rust_type = + resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate); if let Some(imp) = &rust_type.needs_import { imports.insert(imp.clone()); } - let field_name_snake = col.name.to_snake_case(); + let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case()); // If the field name is a Rust keyword, prefix with table name // e.g. column "type" on table "connector" → "connector_type" let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) { - let prefixed = format!( - "{}_{}", - table.name.to_snake_case(), - field_name_snake - ); + let prefixed = format!("{}_{}", table.name.to_snake_case(), field_name_snake); (prefixed, true) } else { let changed = field_name_snake != col.name; @@ -87,12 +84,20 @@ pub fn generate_struct( let has_default = col.column_default.is_some(); let sqlx_gen_attr = if has_pk || has_sql_type || has_default { - let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} }; + let pk_part = if has_pk { + quote! { primary_key, } + } else { + quote! {} + }; let sql_type_part = match &sql_type { Some(t) => quote! { sql_type = #t, }, None => quote! {}, }; - let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} }; + let array_part = if is_sql_array { + quote! { is_array, } + } else { + quote! {} + }; let default_part = match &col.column_default { Some(d) => quote! { column_default = #d, }, None => quote! 
{}, @@ -125,6 +130,34 @@ pub fn generate_struct( (tokens, imports) } +/// Sanitize a candidate Rust identifier: +/// - replace any character that is not ascii-alphanumeric or '_' with '_' +/// - prefix with '_' if the result starts with a digit +/// - fall back to "_field" if the input is empty +/// +/// Lets sqlx-gen survive columns named `user-id`, `created at`, `123`, etc. +/// — they still need a `#[sqlx(rename = "")]` to roundtrip the DB +/// column, which the caller handles via the `changed` flag. +pub(crate) fn sanitize_rust_ident(name: &str) -> String { + if name.is_empty() { + return "_field".to_string(); + } + let mut out: String = name + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '_' { + c + } else { + '_' + } + }) + .collect(); + if out.starts_with(|c: char| c.is_ascii_digit()) { + out.insert(0, '_'); + } + out +} + /// Detect if a column uses a custom SQL type (enum or composite) and return the qualified /// SQL type name for casting, plus whether it's an array. 
/// Returns `(Some("type_name"), true)` for arrays of custom types, @@ -141,7 +174,11 @@ fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option String { let schema = SchemaInfo::default(); - let (tokens, _) = generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), false, TimeCrate::Chrono); - parse_and_format(&tokens) + let (tokens, _) = generate_struct( + table, + DatabaseKind::Postgres, + &schema, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ); + parse_and_format(&tokens).unwrap() } fn gen_with( @@ -228,18 +274,29 @@ mod tests { derives: &[String], overrides: &HashMap, ) -> (String, BTreeSet) { - let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false, TimeCrate::Chrono); - (parse_and_format(&tokens), imports) + let (tokens, imports) = generate_struct( + table, + db, + schema, + derives, + overrides, + false, + TimeCrate::Chrono, + ); + (parse_and_format(&tokens).unwrap(), imports) } // --- basic structure --- #[test] fn test_simple_table() { - let table = make_table("users", vec![ - make_col("id", "int4", false), - make_col("name", "text", false), - ]); + let table = make_table( + "users", + vec![ + make_col("id", "int4", false), + make_col("name", "text", false), + ], + ); let code = gen(&table); assert!(code.contains("pub id: i32")); assert!(code.contains("pub name: String")); @@ -278,10 +335,10 @@ mod tests { #[test] fn test_mix_nullable() { - let table = make_table("users", vec![ - make_col("id", "int4", false), - make_col("bio", "text", true), - ]); + let table = make_table( + "users", + vec![make_col("id", "int4", false), make_col("bio", "text", true)], + ); let code = gen(&table); assert!(code.contains("pub id: i32")); assert!(code.contains("pub bio: Option")); @@ -361,7 +418,13 @@ mod tests { let table = make_table("users", vec![make_col("id", "int4", false)]); let schema = SchemaInfo::default(); let derives = vec!["Serialize".to_string()]; - let (code, _) = 
gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new()); + let (code, _) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &derives, + &HashMap::new(), + ); assert!(code.contains("Serialize")); } @@ -370,7 +433,13 @@ mod tests { let table = make_table("users", vec![make_col("id", "int4", false)]); let schema = SchemaInfo::default(); let derives = vec!["Serialize".to_string(), "Deserialize".to_string()]; - let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new()); + let (_, imports) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &derives, + &HashMap::new(), + ); assert!(imports.iter().any(|i| i.contains("serde"))); } @@ -380,7 +449,13 @@ mod tests { fn test_uuid_import() { let table = make_table("users", vec![make_col("id", "uuid", false)]); let schema = SchemaInfo::default(); - let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new()); + let (_, imports) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + ); assert!(imports.iter().any(|i| i.contains("uuid::Uuid"))); } @@ -388,7 +463,13 @@ mod tests { fn test_timestamptz_import() { let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]); let schema = SchemaInfo::default(); - let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new()); + let (_, imports) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + ); assert!(imports.iter().any(|i| i.contains("chrono"))); } @@ -396,7 +477,13 @@ mod tests { fn test_int4_only_serde_import() { let table = make_table("users", vec![make_col("id", "int4", false)]); let schema = SchemaInfo::default(); - let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new()); + let (_, imports) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + ); assert_eq!(imports.len(), 2); 
assert!(imports.iter().any(|i| i.contains("serde"))); assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen"))); @@ -404,12 +491,21 @@ mod tests { #[test] fn test_multiple_imports_collected() { - let table = make_table("users", vec![ - make_col("id", "uuid", false), - make_col("created_at", "timestamptz", false), - ]); + let table = make_table( + "users", + vec![ + make_col("id", "uuid", false), + make_col("created_at", "timestamptz", false), + ], + ); let schema = SchemaInfo::default(); - let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new()); + let (_, imports) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + ); assert!(imports.iter().any(|i| i.contains("uuid"))); assert!(imports.iter().any(|i| i.contains("chrono"))); } @@ -418,16 +514,20 @@ mod tests { #[test] fn test_mysql_enum_column() { - let table = make_table("users", vec![ColumnInfo { - name: "status".to_string(), - data_type: "enum".to_string(), - udt_name: "enum('active','inactive')".to_string(), - is_nullable: false, - is_primary_key: false, - ordinal_position: 0, - schema_name: "test_db".to_string(), - column_default: None, - }]); + let table = make_table( + "users", + vec![ColumnInfo { + name: "status".to_string(), + data_type: "enum".to_string(), + udt_name: "enum('active','inactive')".to_string(), + is_nullable: false, + is_primary_key: false, + ordinal_position: 0, + schema_name: "test_db".to_string(), + udt_schema: None, + column_default: None, + }], + ); let schema = SchemaInfo::default(); let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new()); assert!(code.contains("UsersStatus")); @@ -436,16 +536,20 @@ mod tests { #[test] fn test_mysql_enum_nullable() { - let table = make_table("users", vec![ColumnInfo { - name: "status".to_string(), - data_type: "enum".to_string(), - udt_name: "enum('a','b')".to_string(), - is_nullable: true, - is_primary_key: false, - ordinal_position: 0, - 
schema_name: "test_db".to_string(), - column_default: None, - }]); + let table = make_table( + "users", + vec![ColumnInfo { + name: "status".to_string(), + data_type: "enum".to_string(), + udt_name: "enum('a','b')".to_string(), + is_nullable: true, + is_primary_key: false, + ordinal_position: 0, + schema_name: "test_db".to_string(), + udt_schema: None, + column_default: None, + }], + ); let schema = SchemaInfo::default(); let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new()); assert!(code.contains("Option")); @@ -467,7 +571,13 @@ mod tests { fn test_type_override_absent() { let table = make_table("users", vec![make_col("data", "jsonb", false)]); let schema = SchemaInfo::default(); - let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new()); + let (code, _) = gen_with( + &table, + &schema, + DatabaseKind::Postgres, + &[], + &HashMap::new(), + ); assert!(code.contains("Value")); } @@ -517,4 +627,50 @@ mod tests { assert!(code.contains("pub name: String")); assert!(!code.contains("sql_type")); } + + // ========== sanitize_rust_ident ========== + + #[test] + fn test_sanitize_replaces_dash() { + assert_eq!(sanitize_rust_ident("user-id"), "user_id"); + } + + #[test] + fn test_sanitize_replaces_space() { + assert_eq!(sanitize_rust_ident("created at"), "created_at"); + } + + #[test] + fn test_sanitize_replaces_dot() { + assert_eq!(sanitize_rust_ident("a.b"), "a_b"); + } + + #[test] + fn test_sanitize_prefixes_leading_digit() { + assert_eq!(sanitize_rust_ident("123abc"), "_123abc"); + } + + #[test] + fn test_sanitize_empty_becomes_placeholder() { + assert_eq!(sanitize_rust_ident(""), "_field"); + } + + #[test] + fn test_sanitize_leaves_valid_ident_unchanged() { + assert_eq!(sanitize_rust_ident("user_id"), "user_id"); + assert_eq!(sanitize_rust_ident("_private"), "_private"); + } + + #[test] + fn test_column_with_dash_generates_valid_rust() { + let table = make_table("users", vec![make_col("user-id", "int4", 
false)]); + let code = gen(&table); + // Must produce a Rust-legal identifier; renamed back to the original via #[sqlx(rename)] + assert!( + code.contains("pub user_id:") || code.contains("user_id:"), + "expected sanitized identifier, got:\n{}", + code + ); + assert!(code.contains("sqlx(rename = \"user-id\")")); + } } diff --git a/crates/sqlx_gen/src/error.rs b/crates/sqlx_gen/src/error.rs index 15f7cac..03f7f54 100644 --- a/crates/sqlx_gen/src/error.rs +++ b/crates/sqlx_gen/src/error.rs @@ -2,6 +2,19 @@ use std::io; #[derive(Debug, thiserror::Error)] pub enum Error { + #[error("Database connection failed ({redacted_url}): {source}")] + Connection { + redacted_url: String, + #[source] + source: sqlx::Error, + }, + + #[error("Permission denied while introspecting: {detail}. Check the DB user's privileges on information_schema / pg_catalog / sqlite_master.")] + PermissionDenied { detail: String }, + + #[error("Schema or relation not found: {detail}. Check `--schemas` and ensure the database contains the expected tables.")] + SchemaNotFound { detail: String }, + #[error("Database error: {0}")] Database(#[from] sqlx::Error), @@ -13,3 +26,114 @@ pub enum Error { } pub type Result = std::result::Result; + +/// Inspect a [`sqlx::Error`] and, if it carries a SQLSTATE we know how to +/// explain, return a richer [`Error`] variant. Otherwise the input is wrapped +/// in [`Error::Database`] unchanged so callers can keep using `?`. +pub fn contextualize_sqlx_error(err: sqlx::Error) -> Error { + use sqlx::Error as Sx; + let code: Option = match &err { + Sx::Database(db) => db.code().map(|c| c.to_string()), + _ => None, + }; + if let Some(code) = code { + // PG: 42501 insufficient_privilege; MySQL: 42000 / 28000. + // PG: 42P01 undefined_table, 3F000 invalid_schema_name; MySQL: 42S02. 
+ match code.as_str() { + "42501" | "28000" => { + return Error::PermissionDenied { + detail: err.to_string(), + }; + } + "42P01" | "3F000" | "42S02" => { + return Error::SchemaNotFound { + detail: err.to_string(), + }; + } + _ => {} + } + } + Error::Database(err) +} + +/// Redact `user:password@host` → `user:****@host` in a database URL so it can +/// be embedded in error messages and logs without leaking credentials. +pub fn redact_url(url: &str) -> String { + let (scheme, rest) = match url.split_once("://") { + Some(pair) => pair, + None => return url.to_string(), + }; + let (userinfo, host_part) = match rest.split_once('@') { + Some(pair) => pair, + None => return url.to_string(), + }; + let redacted_userinfo = match userinfo.split_once(':') { + Some((user, _pw)) => format!("{}:****", user), + None => userinfo.to_string(), + }; + format!("{}://{}@{}", scheme, redacted_userinfo, host_part) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn redacts_password_in_postgres_url() { + assert_eq!( + redact_url("postgres://alice:s3cret@localhost:5432/db"), + "postgres://alice:****@localhost:5432/db" + ); + } + + #[test] + fn redacts_password_in_mysql_url() { + assert_eq!( + redact_url("mysql://root:hunter2@db:3306/app"), + "mysql://root:****@db:3306/app" + ); + } + + #[test] + fn redacts_password_in_postgresql_url() { + assert_eq!( + redact_url("postgresql://u:p@h/d"), + "postgresql://u:****@h/d" + ); + } + + #[test] + fn leaves_passwordless_sqlite_url_unchanged() { + assert_eq!(redact_url("sqlite:///tmp/test.db"), "sqlite:///tmp/test.db"); + } + + #[test] + fn leaves_no_userinfo_unchanged() { + assert_eq!( + redact_url("postgres://localhost/db"), + "postgres://localhost/db" + ); + } + + #[test] + fn leaves_userinfo_without_password_unchanged() { + assert_eq!( + redact_url("postgres://alice@localhost/db"), + "postgres://alice@localhost/db" + ); + } + + #[test] + fn leaves_non_url_string_unchanged() { + assert_eq!(redact_url("not-a-url"), "not-a-url"); + } + 
+ #[test] + fn contextualize_non_database_error_wraps_unchanged() { + let err = sqlx::Error::PoolTimedOut; + match contextualize_sqlx_error(err) { + Error::Database(_) => {} + other => panic!("expected Database, got {:?}", other), + } + } +} diff --git a/crates/sqlx_gen/src/introspect/mod.rs b/crates/sqlx_gen/src/introspect/mod.rs index ff55bff..b182adc 100644 --- a/crates/sqlx_gen/src/introspect/mod.rs +++ b/crates/sqlx_gen/src/introspect/mod.rs @@ -2,7 +2,7 @@ pub mod mysql; pub mod postgres; pub mod sqlite; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[allow(unused)] pub struct ColumnInfo { pub name: String, @@ -10,6 +10,13 @@ pub struct ColumnInfo { pub data_type: String, /// Underlying type name: udt_name (PG), column_type (MySQL), declared type (SQLite) pub udt_name: String, + /// Schema in which `udt_name` is defined. + /// + /// Populated by the Postgres backend (e.g. `auth` for an `auth.role` enum + /// column, `pg_catalog` for builtins). `None` for MySQL/SQLite which have + /// no per-type namespacing. Used to disambiguate enums/composites/domains + /// when two schemas declare a type with the same name. 
+ pub udt_schema: Option, pub is_nullable: bool, pub is_primary_key: bool, pub ordinal_position: i32, diff --git a/crates/sqlx_gen/src/introspect/mysql.rs b/crates/sqlx_gen/src/introspect/mysql.rs index 4685630..ca37638 100644 --- a/crates/sqlx_gen/src/introspect/mysql.rs +++ b/crates/sqlx_gen/src/introspect/mysql.rs @@ -59,7 +59,19 @@ async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result, Vec, Vec, Vec, Vec, Vec, u32, Vec)>(&query); + let mut q = sqlx::query_as::< + _, + ( + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + u32, + Vec, + ), + >(&query); for schema in schemas { q = q.bind(schema); } @@ -69,13 +81,13 @@ async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result = None; for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows { - let schema = String::from_utf8(schema).expect("Could not convert schema name from UTF8 bytes"); - let table = String::from_utf8(table).expect("Could not convert schema name from UTF8 bytes"); - let col_name = String::from_utf8(col_name).expect("Could not convert col_name name from UTF8 bytes"); - let data_type = String::from_utf8(data_type).expect("Could not convert data_type name from UTF8 bytes"); - let column_type = String::from_utf8(column_type).expect("Could not convert column_type name from UTF8 bytes"); - let nullable = String::from_utf8(nullable).expect("Could not convert nullable name from UTF8 bytes"); - let column_key = String::from_utf8(column_key).expect("Could not convert column_key name from UTF8 bytes"); + let schema = utf8_field(schema, "TABLE_SCHEMA")?; + let table = utf8_field(table, "TABLE_NAME")?; + let col_name = utf8_field(col_name, "COLUMN_NAME")?; + let data_type = utf8_field(data_type, "DATA_TYPE")?; + let column_type = utf8_field(column_type, "COLUMN_TYPE")?; + let nullable = utf8_field(nullable, "IS_NULLABLE")?; + let column_key = utf8_field(column_key, "COLUMN_KEY")?; let key = (schema.clone(), table.clone()); if current_key.as_ref() != 
Some(&key) { @@ -86,10 +98,16 @@ async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result Result` metadata field as UTF-8, returning a structured +/// error instead of panicking if the bytes are invalid. +fn utf8_field(bytes: Vec, field: &str) -> Result { + String::from_utf8(bytes).map_err(|_| { + crate::error::Error::Config(format!( + "Database returned non-UTF8 bytes for MySQL information_schema field '{}'. \ + sqlx-gen requires UTF-8 metadata.", + field + )) + }) +} + async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result> { let placeholders: Vec = (0..schemas.len()).map(|_| "?".to_string()).collect(); let query = format!( @@ -143,10 +173,16 @@ async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result Vec let mut view_lookup: HashMap<(&str, &str, &str), Vec> = HashMap::new(); for src in sources { - if let Some(&is_nullable) = - table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str())) - { + if let Some(&is_nullable) = table_lookup.get(&( + src.table_schema.as_str(), + src.table_name.as_str(), + src.column_name.as_str(), + )) { view_lookup .entry((&src.view_schema, &src.view_name, &src.column_name)) .or_default() @@ -276,9 +314,11 @@ fn resolve_view_primary_keys( // Build view column source lookup: (view_schema, view_name, column_name) -> Vec let mut view_lookup: HashMap<(&str, &str, &str), Vec> = HashMap::new(); for src in sources { - if let Some(&is_pk) = - table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str())) - { + if let Some(&is_pk) = table_lookup.get(&( + src.table_schema.as_str(), + src.table_name.as_str(), + src.column_name.as_str(), + )) { view_lookup .entry((&src.view_schema, &src.view_name, &src.column_name)) .or_default() @@ -364,6 +404,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "test_db".to_string(), + udt_schema: None, column_default: None, } } @@ -385,10 +426,7 @@ mod tests { #[test] fn 
test_parse_with_spaces() { - assert_eq!( - parse_enum_variants("enum( 'a' , 'b' )"), - vec!["a", "b"] - ); + assert_eq!(parse_enum_variants("enum( 'a' , 'b' )"), vec!["a", "b"]); } #[test] @@ -445,10 +483,7 @@ mod tests { #[test] fn test_extract_enum_name_format() { - let tables = vec![make_table( - "users", - vec![make_col("status", "enum('a')")], - )]; + let tables = vec![make_table("users", vec![make_col("status", "enum('a')")])]; let enums = extract_enums(&tables); assert_eq!(enums[0].name, "users_status"); } @@ -492,10 +527,7 @@ mod tests { fn test_extract_non_enum_column_ignored() { let tables = vec![make_table( "users", - vec![ - make_col("id", "int(11)"), - make_col("status", "enum('a')"), - ], + vec![make_col("id", "int(11)"), make_col("status", "enum('a')")], )]; let enums = extract_enums(&tables); assert_eq!(enums.len(), 1); @@ -518,6 +550,7 @@ mod tests { is_primary_key: false, ordinal_position: i as i32, schema_name: schema.to_string(), + udt_schema: None, column_default: None, }) .collect(), @@ -543,6 +576,7 @@ mod tests { is_primary_key: false, ordinal_position: i as i32, schema_name: schema.to_string(), + udt_schema: None, column_default: None, }) .collect(), @@ -622,11 +656,7 @@ mod tests { // ========== resolve_view_primary_keys ========== - fn make_table_with_pk( - schema: &str, - name: &str, - columns: Vec<(&str, bool)>, - ) -> TableInfo { + fn make_table_with_pk(schema: &str, name: &str, columns: Vec<(&str, bool)>) -> TableInfo { TableInfo { schema_name: schema.to_string(), name: name.to_string(), @@ -641,6 +671,7 @@ mod tests { is_primary_key: is_pk, ordinal_position: i as i32, schema_name: schema.to_string(), + udt_schema: None, column_default: None, }) .collect(), @@ -649,7 +680,11 @@ mod tests { #[test] fn test_resolve_pk_column() { - let tables = vec![make_table_with_pk("db", "users", vec![("id", true), ("name", false)])]; + let tables = vec![make_table_with_pk( + "db", + "users", + vec![("id", true), ("name", false)], + )]; let mut views 
= vec![make_view("db", "my_view", vec!["id", "name"])]; let sources = vec![ make_source("db", "my_view", "db", "users", "id"), diff --git a/crates/sqlx_gen/src/introspect/postgres.rs b/crates/sqlx_gen/src/introspect/postgres.rs index e99b398..de4e9b9 100644 --- a/crates/sqlx_gen/src/introspect/postgres.rs +++ b/crates/sqlx_gen/src/introspect/postgres.rs @@ -39,7 +39,21 @@ pub async fn introspect( } async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result> { - let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, bool, Option)>( + let rows = sqlx::query_as::< + _, + ( + String, + String, + String, + String, + String, + String, + String, + i32, + bool, + Option, + ), + >( r#" SELECT c.table_schema, @@ -47,6 +61,7 @@ async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result Result = Vec::new(); let mut current_key: Option<(String, String)> = None; - for (schema, table, col_name, data_type, udt_name, nullable, ordinal, is_pk, column_default) in rows { + for ( + schema, + table, + col_name, + data_type, + udt_name, + udt_schema, + nullable, + ordinal, + is_pk, + column_default, + ) in rows + { let key = (schema.clone(), table.clone()); if current_key.as_ref() != Some(&key) { current_key = Some(key); @@ -85,10 +112,20 @@ async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result Result Result> { - let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, Option)>( + let rows = sqlx::query_as::< + _, + ( + String, + String, + String, + String, + String, + String, + String, + i32, + Option, + ), + >( r#" SELECT c.table_schema, @@ -109,6 +159,7 @@ async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result c.column_name, c.data_type, COALESCE(c.udt_name, c.data_type) as udt_name, + COALESCE(c.udt_schema, '') as udt_schema, c.is_nullable, c.ordinal_position, c.column_default @@ -128,7 +179,18 @@ async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result let mut views: Vec = 
Vec::new(); let mut current_key: Option<(String, String)> = None; - for (schema, table, col_name, data_type, udt_name, nullable, ordinal, column_default) in rows { + for ( + schema, + table, + col_name, + data_type, + udt_name, + udt_schema, + nullable, + ordinal, + column_default, + ) in rows + { let key = (schema.clone(), table.clone()); if current_key.as_ref() != Some(&key) { current_key = Some(key); @@ -138,10 +200,20 @@ async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result columns: Vec::new(), }); } - views.last_mut().unwrap().columns.push(ColumnInfo { + let last = views.last_mut().ok_or_else(|| { + crate::error::Error::Config( + "Internal sqlx-gen bug: views vector empty after push".to_string(), + ) + })?; + last.columns.push(ColumnInfo { name: col_name, data_type, udt_name, + udt_schema: if udt_schema.is_empty() { + None + } else { + Some(udt_schema) + }, is_nullable: nullable == "YES", is_primary_key: false, ordinal_position: ordinal, @@ -192,22 +264,17 @@ async fn fetch_view_column_nullability( Ok(rows .into_iter() .map( - |(view_schema, view_name, source_column_name, source_not_null)| { - ViewColumnNullability { - view_schema, - view_name, - source_column_name, - source_not_null, - } + |(view_schema, view_name, source_column_name, source_not_null)| ViewColumnNullability { + view_schema, + view_name, + source_column_name, + source_not_null, }, ) .collect()) } -fn resolve_view_nullability( - views: &mut [TableInfo], - nullability_info: &[ViewColumnNullability], -) { +fn resolve_view_nullability(views: &mut [TableInfo], nullability_info: &[ViewColumnNullability]) { // Build lookup: (view_schema, view_name, column_name) -> Vec let mut lookup: HashMap<(&str, &str, &str), Vec> = HashMap::new(); for info in nullability_info { @@ -291,10 +358,7 @@ async fn fetch_view_column_primary_keys( .collect()) } -fn resolve_view_primary_keys( - views: &mut [TableInfo], - pk_info: &[ViewColumnPrimaryKey], -) { +fn resolve_view_primary_keys(views: &mut 
[TableInfo], pk_info: &[ViewColumnPrimaryKey]) { // Build lookup: (view_schema, view_name, column_name) -> Vec let mut lookup: HashMap<(&str, &str, &str), Vec> = HashMap::new(); for info in pk_info { @@ -352,7 +416,12 @@ async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result> default_variant: None, }); } - enums.last_mut().unwrap().variants.push(variant); + let last = enums.last_mut().ok_or_else(|| { + crate::error::Error::Config( + "Internal sqlx-gen bug: enums vector empty after push".to_string(), + ) + })?; + last.variants.push(variant); } Ok(enums) @@ -402,10 +471,16 @@ async fn fetch_composite_types( fields: Vec::new(), }); } - composites.last_mut().unwrap().fields.push(ColumnInfo { + let last = composites.last_mut().ok_or_else(|| { + crate::error::Error::Config( + "Internal sqlx-gen bug: composites vector empty after push".to_string(), + ) + })?; + last.fields.push(ColumnInfo { name: field_name, data_type: field_type.clone(), udt_name: field_type, + udt_schema: None, is_nullable: nullable == "YES", is_primary_key: false, ordinal_position: ordinal, @@ -465,6 +540,7 @@ mod tests { is_primary_key: false, ordinal_position: i as i32, schema_name: schema.to_string(), + udt_schema: None, column_default: None, }) .collect(), diff --git a/crates/sqlx_gen/src/introspect/sqlite.rs b/crates/sqlx_gen/src/introspect/sqlite.rs index 1424a28..ea31058 100644 --- a/crates/sqlx_gen/src/introspect/sqlite.rs +++ b/crates/sqlx_gen/src/introspect/sqlite.rs @@ -3,10 +3,10 @@ use std::collections::HashMap; use crate::error::Result; use sqlx::SqlitePool; -use super::{ColumnInfo, SchemaInfo, TableInfo}; +use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo}; pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result { - let tables = fetch_tables(pool).await?; + let mut tables = fetch_tables(pool).await?; let mut views = if include_views { fetch_views(pool).await? 
} else { @@ -18,15 +18,124 @@ pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result__enum`) so the rest of the pipeline +/// treats it like a real enum (with PgHasArrayType skipped for SQLite). +async fn extract_check_enums(pool: &SqlitePool, tables: &mut [TableInfo]) -> Result> { + let mut enums = Vec::new(); + + for table in tables.iter_mut() { + let sql: Option<(Option,)> = + sqlx::query_as("SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?") + .bind(&table.name) + .fetch_optional(pool) + .await?; + let Some((Some(ddl),)) = sql else { continue }; + + for col in table.columns.iter_mut() { + if let Some(variants) = parse_check_in_variants(&ddl, &col.name) { + if variants.is_empty() { + continue; + } + let enum_name = format!("{}_{}_enum", table.name, col.name); + col.udt_name = enum_name.clone(); + enums.push(EnumInfo { + schema_name: "main".to_string(), + name: enum_name, + variants, + default_variant: None, + }); + } + } + } + + Ok(enums) +} + +/// Parse `CHECK (col IN ('a','b','c'))` for a given column from a SQLite +/// CREATE TABLE statement. Returns the parsed variants in declaration order +/// or `None` if the column has no IN-style CHECK constraint. 
+fn parse_check_in_variants(ddl: &str, column: &str) -> Option> { + let lower_ddl = ddl.to_ascii_lowercase(); + let lower_col = column.to_ascii_lowercase(); + let mut search_from = 0usize; + + while let Some(rel_check) = lower_ddl[search_from..].find("check") { + let check_pos = search_from + rel_check; + let after_check = &ddl[check_pos + 5..]; + let after_check_lower = &lower_ddl[check_pos + 5..]; + + let open_rel = after_check.find('(')?; + let mut depth = 1i32; + let mut idx = open_rel + 1; + let bytes = after_check.as_bytes(); + while idx < bytes.len() && depth > 0 { + match bytes[idx] { + b'(' => depth += 1, + b')' => depth -= 1, + b'\'' => { + idx += 1; + while idx < bytes.len() && bytes[idx] != b'\'' { + idx += 1; + } + } + _ => {} + } + idx += 1; + } + if depth != 0 { + return None; + } + let body = &after_check[open_rel + 1..idx - 1]; + let body_lower = &after_check_lower[open_rel + 1..idx - 1]; + + search_from = check_pos + 5 + idx; + + if !body_lower.contains(&lower_col) || !body_lower.contains(" in ") { + continue; + } + + if let Some(in_pos) = body_lower.find(" in ") { + let list_start = body[in_pos..].find('(')?; + let list_body = &body[in_pos + list_start + 1..]; + let mut variants = Vec::new(); + let bytes = list_body.as_bytes(); + let mut i = 0; + while i < bytes.len() { + if bytes[i] == b'\'' { + let start = i + 1; + let mut j = start; + while j < bytes.len() && bytes[j] != b'\'' { + j += 1; + } + variants.push(list_body[start..j].to_string()); + i = j + 1; + } else if bytes[i] == b')' { + break; + } else { + i += 1; + } + } + return Some(variants); + } + } + + None +} + async fn fetch_tables(pool: &SqlitePool) -> Result> { let table_names: Vec<(String,)> = sqlx::query_as( "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name", @@ -49,11 +158,10 @@ async fn fetch_tables(pool: &SqlitePool) -> Result> { } async fn fetch_views(pool: &SqlitePool) -> Result> { - let view_names: Vec<(String,)> = sqlx::query_as( 
- "SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name", - ) - .fetch_all(pool) - .await?; + let view_names: Vec<(String,)> = + sqlx::query_as("SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name") + .fetch_all(pool) + .await?; let mut views = Vec::new(); @@ -83,6 +191,7 @@ async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result 0, ordinal_position: cid, @@ -100,7 +209,10 @@ fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) { let mut col_lookup: HashMap<&str, Vec> = HashMap::new(); for table in tables { for col in &table.columns { - col_lookup.entry(&col.name).or_default().push(col.is_nullable); + col_lookup + .entry(&col.name) + .or_default() + .push(col.is_nullable); } } @@ -124,7 +236,10 @@ fn resolve_view_primary_keys(views: &mut [TableInfo], tables: &[TableInfo]) { let mut col_lookup: HashMap<&str, Vec> = HashMap::new(); for table in tables { for col in &table.columns { - col_lookup.entry(&col.name).or_default().push(col.is_primary_key); + col_lookup + .entry(&col.name) + .or_default() + .push(col.is_primary_key); } } @@ -160,6 +275,7 @@ mod tests { is_primary_key: false, ordinal_position: i as i32, schema_name: "main".to_string(), + udt_schema: None, column_default: None, }) .collect(), @@ -181,6 +297,7 @@ mod tests { is_primary_key: false, ordinal_position: i as i32, schema_name: "main".to_string(), + udt_schema: None, column_default: None, }) .collect(), @@ -249,6 +366,7 @@ mod tests { is_primary_key: is_pk, ordinal_position: i as i32, schema_name: "main".to_string(), + udt_schema: None, column_default: None, }) .collect(), @@ -257,7 +375,10 @@ mod tests { #[test] fn test_resolve_pk_unique_match() { - let tables = vec![make_table_with_pk("users", vec![("id", true), ("name", false)])]; + let tables = vec![make_table_with_pk( + "users", + vec![("id", true), ("name", false)], + )]; let mut views = vec![make_view("my_view", vec!["id", "name"])]; resolve_view_primary_keys(&mut views, &tables); 
assert!(views[0].columns[0].is_primary_key); @@ -290,4 +411,49 @@ mod tests { resolve_view_primary_keys(&mut views, &[]); assert!(!views[0].columns[0].is_primary_key); } + + // ========== parse_check_in_variants ========== + + #[test] + fn test_parse_check_in_simple() { + let ddl = "CREATE TABLE t (id INTEGER PRIMARY KEY, status TEXT CHECK (status IN ('active', 'inactive')) NOT NULL)"; + assert_eq!( + parse_check_in_variants(ddl, "status"), + Some(vec!["active".to_string(), "inactive".to_string()]) + ); + } + + #[test] + fn test_parse_check_in_three_variants() { + let ddl = "CREATE TABLE t (priority TEXT CHECK (priority IN ('low','medium','high')))"; + assert_eq!( + parse_check_in_variants(ddl, "priority"), + Some(vec![ + "low".to_string(), + "medium".to_string(), + "high".to_string() + ]) + ); + } + + #[test] + fn test_parse_check_in_returns_none_for_other_column() { + let ddl = "CREATE TABLE t (status TEXT CHECK (status IN ('a','b')))"; + assert_eq!(parse_check_in_variants(ddl, "other"), None); + } + + #[test] + fn test_parse_check_in_returns_none_without_check() { + let ddl = "CREATE TABLE t (status TEXT)"; + assert_eq!(parse_check_in_variants(ddl, "status"), None); + } + + #[test] + fn test_parse_check_in_case_insensitive_keyword() { + let ddl = "CREATE TABLE t (status TEXT check (Status in ('a','b')))"; + assert_eq!( + parse_check_in_variants(ddl, "status"), + Some(vec!["a".to_string(), "b".to_string()]) + ); + } } diff --git a/crates/sqlx_gen/src/main.rs b/crates/sqlx_gen/src/main.rs index 9ccc74f..1efa426 100644 --- a/crates/sqlx_gen/src/main.rs +++ b/crates/sqlx_gen/src/main.rs @@ -10,10 +10,7 @@ use sqlx_gen::writer; #[tokio::main] async fn main() -> Result<()> { - env_logger::Builder::from_env( - env_logger::Env::default().default_filter_or("info"), - ) - .init(); + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); let cli = Cli::parse(); match cli.command { @@ -27,7 +24,7 @@ async fn main() -> Result<()> { async 
fn run_entities(args: EntitiesArgs) -> Result<()> { let db_kind = args.db.database_kind()?; - let type_overrides = args.parse_type_overrides(); + let type_overrides = args.parse_type_overrides_checked()?; info!( "Connecting to {} database...", @@ -38,24 +35,47 @@ async fn run_entities(args: EntitiesArgs) -> Result<()> { } ); + let redacted = sqlx_gen::error::redact_url(&args.db.database_url); + let conn_err = |source: sqlx::Error| sqlx_gen::error::Error::Connection { + redacted_url: redacted.clone(), + source, + }; + // If introspection itself returns a sqlx::Error wrapped in our generic + // Database variant, re-classify by SQLSTATE so the user sees an + // actionable message instead of the raw "error returned from database". + let map_introspect_error = |e: sqlx_gen::error::Error| match e { + sqlx_gen::error::Error::Database(inner) => sqlx_gen::error::contextualize_sqlx_error(inner), + other => other, + }; + let mut schema_info = match db_kind { DatabaseKind::Postgres => { - let pool = PgPool::connect(&args.db.database_url).await?; - let info = - introspect::postgres::introspect(&pool, &args.db.schemas, args.views).await?; + let pool = PgPool::connect(&args.db.database_url) + .await + .map_err(conn_err)?; + let info = introspect::postgres::introspect(&pool, &args.db.schemas, args.views) + .await + .map_err(map_introspect_error)?; pool.close().await; info } DatabaseKind::Mysql => { - let pool = MySqlPool::connect(&args.db.database_url).await?; - let info = - introspect::mysql::introspect(&pool, &args.db.schemas, args.views).await?; + let pool = MySqlPool::connect(&args.db.database_url) + .await + .map_err(conn_err)?; + let info = introspect::mysql::introspect(&pool, &args.db.schemas, args.views) + .await + .map_err(map_introspect_error)?; pool.close().await; info } DatabaseKind::Sqlite => { - let pool = SqlitePool::connect(&args.db.database_url).await?; - let info = introspect::sqlite::introspect(&pool, args.views).await?; + let pool = 
SqlitePool::connect(&args.db.database_url) + .await + .map_err(conn_err)?; + let info = introspect::sqlite::introspect(&pool, args.views) + .await + .map_err(map_introspect_error)?; pool.close().await; info } @@ -83,15 +103,36 @@ async fn run_entities(args: EntitiesArgs) -> Result<()> { schema_info.composite_types.len(), schema_info.domains.len(), ); + if db_kind == DatabaseKind::Postgres { + let needed = sqlx_gen::codegen::required_pg_search_path(&schema_info); + if !needed.is_empty() { + info!( + "Generated types reference non-default schemas: {}. \ + Configure the sqlx pool to include them in search_path, e.g. \ + `SET search_path TO public, {}`", + needed.join(", "), + needed.join(", ") + ); + } + } + if table_count == 0 && view_count == 0 && enum_count == 0 { + warn!( + "No tables, views, or enums found in schemas {:?}. \ + Either the schema is empty or the DB user lacks SELECT on \ + information_schema. Check credentials and `--schemas`.", + args.db.schemas + ); + } - let files = codegen::generate( + let files = codegen::generate_with_domain_style( &schema_info, db_kind, &args.derives, &type_overrides, args.single_file, args.time_crate, - ); + args.domain_style, + )?; writer::write_files(&files, &args.output_dir, args.single_file, args.dry_run)?; @@ -108,7 +149,10 @@ fn run_crud(args: CrudArgs) -> Result<()> { let entities_module = args.resolve_entities_module()?; // Validate that the resolved module is a Rust module path, not a file path - if entities_module.contains('/') || entities_module.contains('\\') || entities_module.ends_with(".rs") { + if entities_module.contains('/') + || entities_module.contains('\\') + || entities_module.ends_with(".rs") + { return Err(sqlx_gen::error::Error::Config(format!( "--entities-module must be a Rust module path (e.g. 
\"crate::models::users\"), got \"{}\"", entities_module @@ -144,7 +188,7 @@ fn run_crud(args: CrudArgs) -> Result<()> { ); let tab_spaces = codegen::detect_tab_spaces(&args.output_dir); - let code = codegen::format_tokens_with_imports_and_tab_spaces(&tokens, &imports, tab_spaces); + let code = codegen::format_tokens_with_imports_and_tab_spaces(&tokens, &imports, tab_spaces)?; if args.dry_run { println!("{}", code); @@ -160,10 +204,7 @@ fn run_crud(args: CrudArgs) -> Result<()> { let filename = format!("{}_repository.rs", normalized); let file_path = args.output_dir.join(&filename); - let content = format!( - "// Auto-generated by sqlx-gen. Do not edit.\n\n{}", - code - ); + let content = format!("// Auto-generated by sqlx-gen. Do not edit.\n\n{}", code); std::fs::write(&file_path, &content)?; info!("Wrote {}", file_path.display()); let edition = detect_edition(&args.output_dir); diff --git a/crates/sqlx_gen/src/typemap/mod.rs b/crates/sqlx_gen/src/typemap/mod.rs index 4013c7d..55fda8d 100644 --- a/crates/sqlx_gen/src/typemap/mod.rs +++ b/crates/sqlx_gen/src/typemap/mod.rs @@ -54,11 +54,20 @@ pub fn map_column( // Check type overrides first if let Some(override_type) = overrides.get(&col.udt_name) { let rt = RustType::simple(override_type); - return if col.is_nullable { rt.wrap_option() } else { rt }; + return if col.is_nullable { + rt.wrap_option() + } else { + rt + }; } let base = match db_kind { - DatabaseKind::Postgres => postgres::map_type(&col.udt_name, schema_info, time_crate), + DatabaseKind::Postgres => postgres::map_type_qualified( + &col.udt_name, + col.udt_schema.as_deref(), + schema_info, + time_crate, + ), DatabaseKind::Mysql => mysql::map_type(&col.data_type, &col.udt_name, time_crate), DatabaseKind::Sqlite => sqlite::map_type(&col.udt_name, time_crate), }; @@ -85,6 +94,7 @@ mod tests { is_primary_key: false, ordinal_position: 0, schema_name: "public".to_string(), + udt_schema: None, column_default: None, } } @@ -175,7 +185,13 @@ mod tests { let 
schema = SchemaInfo::default(); let mut overrides = HashMap::new(); overrides.insert("uuid".to_string(), "MyUuid".to_string()); - let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono); + let rt = map_column( + &col, + DatabaseKind::Postgres, + &schema, + &overrides, + TimeCrate::Chrono, + ); assert_eq!(rt.path, "MyUuid"); assert!(rt.needs_import.is_none()); } @@ -186,7 +202,13 @@ mod tests { let schema = SchemaInfo::default(); let mut overrides = HashMap::new(); overrides.insert("uuid".to_string(), "MyUuid".to_string()); - let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono); + let rt = map_column( + &col, + DatabaseKind::Postgres, + &schema, + &overrides, + TimeCrate::Chrono, + ); assert_eq!(rt.path, "Option"); } @@ -195,7 +217,13 @@ mod tests { let col = make_col("int4", "integer", false); let schema = SchemaInfo::default(); let overrides = HashMap::new(); - let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono); + let rt = map_column( + &col, + DatabaseKind::Postgres, + &schema, + &overrides, + TimeCrate::Chrono, + ); assert_eq!(rt.path, "i32"); } @@ -204,8 +232,13 @@ mod tests { let col = make_col("int4", "integer", true); let schema = SchemaInfo::default(); let overrides = HashMap::new(); - let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono); + let rt = map_column( + &col, + DatabaseKind::Postgres, + &schema, + &overrides, + TimeCrate::Chrono, + ); assert_eq!(rt.path, "Option"); } } - diff --git a/crates/sqlx_gen/src/typemap/mysql.rs b/crates/sqlx_gen/src/typemap/mysql.rs index fe6909a..8e7fb5b 100644 --- a/crates/sqlx_gen/src/typemap/mysql.rs +++ b/crates/sqlx_gen/src/typemap/mysql.rs @@ -49,9 +49,7 @@ pub fn map_type(data_type: &str, column_type: &str, time_crate: TimeCrate) -> Ru } "float" => RustType::simple("f32"), "double" => RustType::simple("f64"), - "decimal" | "numeric" => { - 
RustType::with_import("Decimal", "use rust_decimal::Decimal;") - } + "decimal" | "numeric" => RustType::with_import("Decimal", "use rust_decimal::Decimal;"), "varchar" | "char" | "text" | "tinytext" | "mediumtext" | "longtext" | "enum" | "set" => { RustType::simple("String") } @@ -67,16 +65,30 @@ pub fn map_type(data_type: &str, column_type: &str, time_crate: TimeCrate) -> Ru TimeCrate::Time => RustType::with_import("Time", "use time::Time;"), }, "datetime" => match time_crate { - TimeCrate::Chrono => RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;"), - TimeCrate::Time => RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;"), + TimeCrate::Chrono => { + RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;") + } + TimeCrate::Time => { + RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;") + } }, "timestamp" => match time_crate { - TimeCrate::Chrono => RustType::with_import("DateTime", "use chrono::{DateTime, Utc};"), + TimeCrate::Chrono => { + RustType::with_import("DateTime", "use chrono::{DateTime, Utc};") + } TimeCrate::Time => RustType::with_import("OffsetDateTime", "use time::OffsetDateTime;"), }, "json" => RustType::with_import("Value", "use serde_json::Value;"), "year" => RustType::simple("i16"), - "bit" => RustType::simple("Vec"), + "bit" => { + // BIT(1) is the idiomatic MySQL boolean. Treat anything wider as raw bytes. 
+ if ct == "bit(1)" { + RustType::simple("bool") + } else { + RustType::simple("Vec") + } + } + "boolean" | "bool" => RustType::simple("bool"), _ => RustType::simple("String"), } } @@ -96,7 +108,10 @@ mod tests { #[test] fn test_tinyint1_is_bool() { - assert_eq!(map_type("tinyint", "tinyint(1)", TimeCrate::Chrono).path, "bool"); + assert_eq!( + map_type("tinyint", "tinyint(1)", TimeCrate::Chrono).path, + "bool" + ); } #[test] @@ -106,29 +121,44 @@ mod tests { #[test] fn test_tinyint_unsigned() { - assert_eq!(map_type("tinyint", "tinyint unsigned", TimeCrate::Chrono).path, "u8"); + assert_eq!( + map_type("tinyint", "tinyint unsigned", TimeCrate::Chrono).path, + "u8" + ); } #[test] fn test_tinyint3_signed() { - assert_eq!(map_type("tinyint", "tinyint(3)", TimeCrate::Chrono).path, "i8"); + assert_eq!( + map_type("tinyint", "tinyint(3)", TimeCrate::Chrono).path, + "i8" + ); } #[test] fn test_tinyint3_unsigned() { - assert_eq!(map_type("tinyint", "tinyint(3) unsigned", TimeCrate::Chrono).path, "u8"); + assert_eq!( + map_type("tinyint", "tinyint(3) unsigned", TimeCrate::Chrono).path, + "u8" + ); } // --- smallint --- #[test] fn test_smallint_signed() { - assert_eq!(map_type("smallint", "smallint", TimeCrate::Chrono).path, "i16"); + assert_eq!( + map_type("smallint", "smallint", TimeCrate::Chrono).path, + "i16" + ); } #[test] fn test_smallint_unsigned() { - assert_eq!(map_type("smallint", "smallint unsigned", TimeCrate::Chrono).path, "u16"); + assert_eq!( + map_type("smallint", "smallint unsigned", TimeCrate::Chrono).path, + "u16" + ); } // --- int/mediumint --- @@ -140,17 +170,26 @@ mod tests { #[test] fn test_int_unsigned() { - assert_eq!(map_type("int", "int unsigned", TimeCrate::Chrono).path, "u32"); + assert_eq!( + map_type("int", "int unsigned", TimeCrate::Chrono).path, + "u32" + ); } #[test] fn test_mediumint_signed() { - assert_eq!(map_type("mediumint", "mediumint", TimeCrate::Chrono).path, "i32"); + assert_eq!( + map_type("mediumint", "mediumint", 
TimeCrate::Chrono).path, + "i32" + ); } #[test] fn test_mediumint_unsigned() { - assert_eq!(map_type("mediumint", "mediumint unsigned", TimeCrate::Chrono).path, "u32"); + assert_eq!( + map_type("mediumint", "mediumint unsigned", TimeCrate::Chrono).path, + "u32" + ); } #[test] @@ -160,7 +199,10 @@ mod tests { #[test] fn test_int11_unsigned() { - assert_eq!(map_type("int", "int(11) unsigned", TimeCrate::Chrono).path, "u32"); + assert_eq!( + map_type("int", "int(11) unsigned", TimeCrate::Chrono).path, + "u32" + ); } // --- bigint --- @@ -172,12 +214,18 @@ mod tests { #[test] fn test_bigint_unsigned() { - assert_eq!(map_type("bigint", "bigint unsigned", TimeCrate::Chrono).path, "u64"); + assert_eq!( + map_type("bigint", "bigint unsigned", TimeCrate::Chrono).path, + "u64" + ); } #[test] fn test_bigint20_signed() { - assert_eq!(map_type("bigint", "bigint(20)", TimeCrate::Chrono).path, "i64"); + assert_eq!( + map_type("bigint", "bigint(20)", TimeCrate::Chrono).path, + "i64" + ); } // --- floats --- @@ -212,12 +260,18 @@ mod tests { #[test] fn test_varchar() { - assert_eq!(map_type("varchar", "varchar(255)", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("varchar", "varchar(255)", TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_char() { - assert_eq!(map_type("char", "char(1)", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("char", "char(1)", TimeCrate::Chrono).path, + "String" + ); } #[test] @@ -227,34 +281,52 @@ mod tests { #[test] fn test_tinytext() { - assert_eq!(map_type("tinytext", "tinytext", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("tinytext", "tinytext", TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_mediumtext() { - assert_eq!(map_type("mediumtext", "mediumtext", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("mediumtext", "mediumtext", TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_longtext() { - assert_eq!(map_type("longtext", "longtext", TimeCrate::Chrono).path, 
"String"); + assert_eq!( + map_type("longtext", "longtext", TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_set() { - assert_eq!(map_type("set", "set('a','b')", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("set", "set('a','b')", TimeCrate::Chrono).path, + "String" + ); } // --- binary --- #[test] fn test_binary() { - assert_eq!(map_type("binary", "binary(16)", TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("binary", "binary(16)", TimeCrate::Chrono).path, + "Vec" + ); } #[test] fn test_varbinary() { - assert_eq!(map_type("varbinary", "varbinary(255)", TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("varbinary", "varbinary(255)", TimeCrate::Chrono).path, + "Vec" + ); } #[test] @@ -264,7 +336,10 @@ mod tests { #[test] fn test_tinyblob() { - assert_eq!(map_type("tinyblob", "tinyblob", TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("tinyblob", "tinyblob", TimeCrate::Chrono).path, + "Vec" + ); } // --- dates --- @@ -312,15 +387,31 @@ mod tests { } #[test] - fn test_bit() { - assert_eq!(map_type("bit", "bit(1)", TimeCrate::Chrono).path, "Vec"); + fn test_bit1_is_bool() { + assert_eq!(map_type("bit", "bit(1)", TimeCrate::Chrono).path, "bool"); + } + + #[test] + fn test_bit8_is_bytes() { + assert_eq!(map_type("bit", "bit(8)", TimeCrate::Chrono).path, "Vec"); + } + + #[test] + fn test_boolean_alias_is_bool() { + assert_eq!( + map_type("boolean", "boolean", TimeCrate::Chrono).path, + "bool" + ); } // --- enum placeholder --- #[test] fn test_enum_placeholder() { - assert_eq!(map_type("enum", "enum('a','b','c')", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("enum", "enum('a','b','c')", TimeCrate::Chrono).path, + "String" + ); } // --- case insensitive --- @@ -332,14 +423,20 @@ mod tests { #[test] fn test_case_insensitive_tinyint1() { - assert_eq!(map_type("TINYINT", "TINYINT(1)", TimeCrate::Chrono).path, "bool"); + assert_eq!( + map_type("TINYINT", "TINYINT(1)", TimeCrate::Chrono).path, + "bool" + ); } 
// --- fallback --- #[test] fn test_geometry_fallback() { - assert_eq!(map_type("geometry", "geometry", TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("geometry", "geometry", TimeCrate::Chrono).path, + "String" + ); } #[test] @@ -356,7 +453,10 @@ mod tests { #[test] fn test_resolve_enum_user_roles_role_type() { - assert_eq!(resolve_enum_type("user_roles", "role_type"), "UserRolesRoleType"); + assert_eq!( + resolve_enum_type("user_roles", "role_type"), + "UserRolesRoleType" + ); } #[test] @@ -375,14 +475,22 @@ mod tests { fn test_timestamp_time_crate() { let rt = map_type("timestamp", "timestamp", TimeCrate::Time); assert_eq!(rt.path, "OffsetDateTime"); - assert!(rt.needs_import.as_ref().unwrap().contains("time::OffsetDateTime")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("time::OffsetDateTime")); } #[test] fn test_datetime_time_crate() { let rt = map_type("datetime", "datetime", TimeCrate::Time); assert_eq!(rt.path, "PrimitiveDateTime"); - assert!(rt.needs_import.as_ref().unwrap().contains("time::PrimitiveDateTime")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("time::PrimitiveDateTime")); } #[test] diff --git a/crates/sqlx_gen/src/typemap/postgres.rs b/crates/sqlx_gen/src/typemap/postgres.rs index 0b9269e..8e370b1 100644 --- a/crates/sqlx_gen/src/typemap/postgres.rs +++ b/crates/sqlx_gen/src/typemap/postgres.rs @@ -1,5 +1,3 @@ -use heck::ToUpperCamelCase; - use super::RustType; use crate::cli::TimeCrate; use crate::introspect::SchemaInfo; @@ -10,49 +8,110 @@ pub fn is_builtin(udt_name: &str) -> bool { matches!( udt_name, "bool" - | "int2" | "smallint" | "smallserial" - | "int4" | "int" | "integer" | "serial" - | "int8" | "bigint" | "bigserial" - | "float4" | "real" - | "float8" | "double precision" - | "numeric" | "decimal" - | "varchar" | "text" | "bpchar" | "char" | "name" | "citext" + | "int2" + | "smallint" + | "smallserial" + | "int4" + | "int" + | "integer" + | "serial" + | "int8" + | "bigint" + | 
"bigserial" + | "float4" + | "real" + | "float8" + | "double precision" + | "numeric" + | "decimal" + | "varchar" + | "text" + | "bpchar" + | "char" + | "name" + | "citext" | "bytea" - | "timestamp" | "timestamp without time zone" - | "timestamptz" | "timestamp with time zone" + | "timestamp" + | "timestamp without time zone" + | "timestamptz" + | "timestamp with time zone" | "date" - | "time" | "time without time zone" - | "timetz" | "time with time zone" + | "time" + | "time without time zone" + | "timetz" + | "time with time zone" | "uuid" - | "json" | "jsonb" - | "inet" | "cidr" + | "json" + | "jsonb" + | "inet" + | "cidr" + | "interval" | "oid" ) } pub fn map_type(udt_name: &str, schema_info: &SchemaInfo, time_crate: TimeCrate) -> RustType { - // Handle array types (prefixed with '_' in PG) + map_type_qualified(udt_name, None, schema_info, time_crate) +} + +// Shared with the rest of codegen — see codegen::rust_type_name_for. +use crate::codegen::rust_type_name_for as rust_type_name_inner; + +fn rust_type_name(schema: &str, name: &str, schema_info: &SchemaInfo) -> String { + rust_type_name_inner(schema_info, schema, name) +} + +/// Map a PG type name to a Rust type, respecting `udt_schema` when present so +/// that two schemas declaring the same name (e.g. `auth.role` vs +/// `billing.role`) resolve to distinct Rust idents. +pub fn map_type_qualified( + udt_name: &str, + udt_schema: Option<&str>, + schema_info: &SchemaInfo, + time_crate: TimeCrate, +) -> RustType { + // Handle array types: PG's information_schema may report them either as + // `_int4` (information_schema.columns.udt_name) or `integer[]` + // (pg_catalog.format_type). Both should produce Vec. 
if let Some(inner) = udt_name.strip_prefix('_') { - let inner_type = map_type(inner, schema_info, time_crate); + let inner_type = map_type_qualified(inner, udt_schema, schema_info, time_crate); + return inner_type.wrap_vec(); + } + if let Some(inner) = udt_name.strip_suffix("[]") { + let inner_type = map_type_qualified(inner.trim(), udt_schema, schema_info, time_crate); return inner_type.wrap_vec(); } - // Check if it's a known enum - if schema_info.enums.iter().any(|e| e.name == udt_name) { - let name = udt_name.to_upper_camel_case(); + // Schema-aware enum lookup. When udt_schema is provided we restrict to + // exact (schema, name) matches first; otherwise we fall back to the + // first name match so that legacy callers (and synthetic test fixtures) + // keep working. + let enum_match = schema_info + .enums + .iter() + .find(|e| e.name == udt_name && udt_schema.map(|s| s == e.schema_name).unwrap_or(true)); + if let Some(e) = enum_match { + let name = rust_type_name(&e.schema_name, &e.name, schema_info); return RustType::with_import(&name, &format!("use super::types::{};", name)); } - // Check if it's a known composite type - if schema_info.composite_types.iter().any(|c| c.name == udt_name) { - let name = udt_name.to_upper_camel_case(); + let composite_match = schema_info + .composite_types + .iter() + .find(|c| c.name == udt_name && udt_schema.map(|s| s == c.schema_name).unwrap_or(true)); + if let Some(c) = composite_match { + let name = rust_type_name(&c.schema_name, &c.name, schema_info); return RustType::with_import(&name, &format!("use super::types::{};", name)); } - // Check if it's a known domain - if let Some(domain) = schema_info.domains.iter().find(|d| d.name == udt_name) { - // Map to the domain's base type - return map_type(&domain.base_type, schema_info, time_crate); + let domain_match = schema_info + .domains + .iter() + .find(|d| d.name == udt_name && udt_schema.map(|s| s == d.schema_name).unwrap_or(true)); + if let Some(domain) = domain_match { + 
// Map to the domain's base type — base type lives in pg_catalog so + // schema is irrelevant for the recursive lookup. + return map_type_qualified(&domain.base_type, None, schema_info, time_crate); } match udt_name { @@ -62,17 +121,21 @@ pub fn map_type(udt_name: &str, schema_info: &SchemaInfo, time_crate: TimeCrate) "int8" | "bigint" | "bigserial" => RustType::simple("i64"), "float4" | "real" => RustType::simple("f32"), "float8" | "double precision" => RustType::simple("f64"), - "numeric" | "decimal" => { - RustType::with_import("Decimal", "use rust_decimal::Decimal;") - } + "numeric" | "decimal" => RustType::with_import("Decimal", "use rust_decimal::Decimal;"), "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"), "bytea" => RustType::simple("Vec"), "timestamp" | "timestamp without time zone" => match time_crate { - TimeCrate::Chrono => RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;"), - TimeCrate::Time => RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;"), + TimeCrate::Chrono => { + RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;") + } + TimeCrate::Time => { + RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;") + } }, "timestamptz" | "timestamp with time zone" => match time_crate { - TimeCrate::Chrono => RustType::with_import("DateTime", "use chrono::{DateTime, Utc};"), + TimeCrate::Chrono => { + RustType::with_import("DateTime", "use chrono::{DateTime, Utc};") + } TimeCrate::Time => RustType::with_import("OffsetDateTime", "use time::OffsetDateTime;"), }, "date" => match time_crate { @@ -88,12 +151,9 @@ pub fn map_type(udt_name: &str, schema_info: &SchemaInfo, time_crate: TimeCrate) TimeCrate::Time => RustType::with_import("Time", "use time::Time;"), }, "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"), - "json" | "jsonb" => { - RustType::with_import("Value", "use serde_json::Value;") - } - "inet" | "cidr" => { - 
RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;") - } + "json" | "jsonb" => RustType::with_import("Value", "use serde_json::Value;"), + "inet" | "cidr" => RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;"), + "interval" => RustType::with_import("PgInterval", "use sqlx::postgres::types::PgInterval;"), "oid" => RustType::simple("u32"), _ => RustType::simple("String"), // fallback } @@ -147,72 +207,114 @@ mod tests { #[test] fn test_bool() { - assert_eq!(map_type("bool", &empty_schema(), TimeCrate::Chrono).path, "bool"); + assert_eq!( + map_type("bool", &empty_schema(), TimeCrate::Chrono).path, + "bool" + ); } #[test] fn test_int2() { - assert_eq!(map_type("int2", &empty_schema(), TimeCrate::Chrono).path, "i16"); + assert_eq!( + map_type("int2", &empty_schema(), TimeCrate::Chrono).path, + "i16" + ); } #[test] fn test_smallint() { - assert_eq!(map_type("smallint", &empty_schema(), TimeCrate::Chrono).path, "i16"); + assert_eq!( + map_type("smallint", &empty_schema(), TimeCrate::Chrono).path, + "i16" + ); } #[test] fn test_smallserial() { - assert_eq!(map_type("smallserial", &empty_schema(), TimeCrate::Chrono).path, "i16"); + assert_eq!( + map_type("smallserial", &empty_schema(), TimeCrate::Chrono).path, + "i16" + ); } #[test] fn test_int4() { - assert_eq!(map_type("int4", &empty_schema(), TimeCrate::Chrono).path, "i32"); + assert_eq!( + map_type("int4", &empty_schema(), TimeCrate::Chrono).path, + "i32" + ); } #[test] fn test_integer() { - assert_eq!(map_type("integer", &empty_schema(), TimeCrate::Chrono).path, "i32"); + assert_eq!( + map_type("integer", &empty_schema(), TimeCrate::Chrono).path, + "i32" + ); } #[test] fn test_serial() { - assert_eq!(map_type("serial", &empty_schema(), TimeCrate::Chrono).path, "i32"); + assert_eq!( + map_type("serial", &empty_schema(), TimeCrate::Chrono).path, + "i32" + ); } #[test] fn test_int8() { - assert_eq!(map_type("int8", &empty_schema(), TimeCrate::Chrono).path, "i64"); + assert_eq!( + map_type("int8", 
&empty_schema(), TimeCrate::Chrono).path, + "i64" + ); } #[test] fn test_bigint() { - assert_eq!(map_type("bigint", &empty_schema(), TimeCrate::Chrono).path, "i64"); + assert_eq!( + map_type("bigint", &empty_schema(), TimeCrate::Chrono).path, + "i64" + ); } #[test] fn test_bigserial() { - assert_eq!(map_type("bigserial", &empty_schema(), TimeCrate::Chrono).path, "i64"); + assert_eq!( + map_type("bigserial", &empty_schema(), TimeCrate::Chrono).path, + "i64" + ); } #[test] fn test_float4() { - assert_eq!(map_type("float4", &empty_schema(), TimeCrate::Chrono).path, "f32"); + assert_eq!( + map_type("float4", &empty_schema(), TimeCrate::Chrono).path, + "f32" + ); } #[test] fn test_real() { - assert_eq!(map_type("real", &empty_schema(), TimeCrate::Chrono).path, "f32"); + assert_eq!( + map_type("real", &empty_schema(), TimeCrate::Chrono).path, + "f32" + ); } #[test] fn test_float8() { - assert_eq!(map_type("float8", &empty_schema(), TimeCrate::Chrono).path, "f64"); + assert_eq!( + map_type("float8", &empty_schema(), TimeCrate::Chrono).path, + "f64" + ); } #[test] fn test_double_precision() { - assert_eq!(map_type("double precision", &empty_schema(), TimeCrate::Chrono).path, "f64"); + assert_eq!( + map_type("double precision", &empty_schema(), TimeCrate::Chrono).path, + "f64" + ); } #[test] @@ -230,32 +332,50 @@ mod tests { #[test] fn test_varchar() { - assert_eq!(map_type("varchar", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("varchar", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_text() { - assert_eq!(map_type("text", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("text", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_bpchar() { - assert_eq!(map_type("bpchar", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("bpchar", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_citext() { - 
assert_eq!(map_type("citext", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("citext", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_name() { - assert_eq!(map_type("name", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("name", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_bytea() { - assert_eq!(map_type("bytea", &empty_schema(), TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("bytea", &empty_schema(), TimeCrate::Chrono).path, + "Vec" + ); } #[test] @@ -325,19 +445,51 @@ mod tests { #[test] fn test_oid() { - assert_eq!(map_type("oid", &empty_schema(), TimeCrate::Chrono).path, "u32"); + assert_eq!( + map_type("oid", &empty_schema(), TimeCrate::Chrono).path, + "u32" + ); + } + + #[test] + fn test_interval_uses_pg_interval() { + let rt = map_type("interval", &empty_schema(), TimeCrate::Chrono); + assert_eq!(rt.path, "PgInterval"); + assert!(rt.needs_import.as_ref().unwrap().contains("PgInterval")); } // --- arrays --- #[test] fn test_array_int4() { - assert_eq!(map_type("_int4", &empty_schema(), TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("_int4", &empty_schema(), TimeCrate::Chrono).path, + "Vec" + ); + } + + #[test] + fn test_array_bracket_notation() { + assert_eq!( + map_type("integer[]", &empty_schema(), TimeCrate::Chrono).path, + "Vec" + ); + } + + #[test] + fn test_array_bracket_text() { + assert_eq!( + map_type("text[]", &empty_schema(), TimeCrate::Chrono).path, + "Vec" + ); } #[test] fn test_array_text() { - assert_eq!(map_type("_text", &empty_schema(), TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("_text", &empty_schema(), TimeCrate::Chrono).path, + "Vec" + ); } #[test] @@ -349,7 +501,10 @@ mod tests { #[test] fn test_array_bool() { - assert_eq!(map_type("_bool", &empty_schema(), TimeCrate::Chrono).path, "Vec"); + assert_eq!( + map_type("_bool", &empty_schema(), TimeCrate::Chrono).path, + "Vec" + ); } 
#[test] @@ -361,7 +516,10 @@ mod tests { #[test] fn test_array_bytea() { - assert_eq!(map_type("_bytea", &empty_schema(), TimeCrate::Chrono).path, "Vec>"); + assert_eq!( + map_type("_bytea", &empty_schema(), TimeCrate::Chrono).path, + "Vec>" + ); } // --- enums/composites/domains --- @@ -371,7 +529,11 @@ mod tests { let schema = schema_with_enum("status"); let rt = map_type("status", &schema, TimeCrate::Chrono); assert_eq!(rt.path, "Status"); - assert!(rt.needs_import.as_ref().unwrap().contains("super::types::Status")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("super::types::Status")); } #[test] @@ -386,7 +548,11 @@ mod tests { let schema = schema_with_composite("address"); let rt = map_type("address", &schema, TimeCrate::Chrono); assert_eq!(rt.path, "Address"); - assert!(rt.needs_import.as_ref().unwrap().contains("super::types::Address")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("super::types::Address")); } #[test] @@ -439,12 +605,18 @@ mod tests { #[test] fn test_geometry_fallback() { - assert_eq!(map_type("geometry", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("geometry", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } #[test] fn test_hstore_fallback() { - assert_eq!(map_type("hstore", &empty_schema(), TimeCrate::Chrono).path, "String"); + assert_eq!( + map_type("hstore", &empty_schema(), TimeCrate::Chrono).path, + "String" + ); } // --- time crate --- @@ -453,14 +625,22 @@ mod tests { fn test_timestamptz_time_crate() { let rt = map_type("timestamptz", &empty_schema(), TimeCrate::Time); assert_eq!(rt.path, "OffsetDateTime"); - assert!(rt.needs_import.as_ref().unwrap().contains("time::OffsetDateTime")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("time::OffsetDateTime")); } #[test] fn test_timestamp_time_crate() { let rt = map_type("timestamp", &empty_schema(), TimeCrate::Time); assert_eq!(rt.path, "PrimitiveDateTime"); - 
assert!(rt.needs_import.as_ref().unwrap().contains("time::PrimitiveDateTime")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("time::PrimitiveDateTime")); } #[test] diff --git a/crates/sqlx_gen/src/typemap/sqlite.rs b/crates/sqlx_gen/src/typemap/sqlite.rs index 38cafe2..da2f301 100644 --- a/crates/sqlx_gen/src/typemap/sqlite.rs +++ b/crates/sqlx_gen/src/typemap/sqlite.rs @@ -21,8 +21,12 @@ pub fn map_type(declared_type: &str, time_crate: TimeCrate) -> RustType { } if upper.contains("TIMESTAMP") || upper.contains("DATETIME") { return match time_crate { - TimeCrate::Chrono => RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;"), - TimeCrate::Time => RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;"), + TimeCrate::Chrono => { + RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;") + } + TimeCrate::Time => { + RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;") + } }; } if upper.contains("DATE") { @@ -38,7 +42,9 @@ pub fn map_type(declared_type: &str, time_crate: TimeCrate) -> RustType { }; } if upper.contains("NUMERIC") || upper.contains("DECIMAL") { - return RustType::simple("f64"); + // f64 would silently lose precision for currency-style values. sqlx exposes + // the same Decimal type for sqlite as for postgres/mysql. 
+ return RustType::with_import("Decimal", "use rust_decimal::Decimal;"); } // Default: SQLite is loosely typed @@ -164,13 +170,16 @@ mod tests { } #[test] - fn test_numeric() { - assert_eq!(map_type("NUMERIC", TimeCrate::Chrono).path, "f64"); + fn test_numeric_uses_decimal() { + let rt = map_type("NUMERIC", TimeCrate::Chrono); + assert_eq!(rt.path, "Decimal"); + assert!(rt.needs_import.as_ref().unwrap().contains("rust_decimal")); } #[test] - fn test_decimal() { - assert_eq!(map_type("DECIMAL", TimeCrate::Chrono).path, "f64"); + fn test_decimal_uses_decimal() { + let rt = map_type("DECIMAL", TimeCrate::Chrono); + assert_eq!(rt.path, "Decimal"); } #[test] @@ -189,14 +198,22 @@ mod tests { fn test_timestamp_time_crate() { let rt = map_type("TIMESTAMP", TimeCrate::Time); assert_eq!(rt.path, "PrimitiveDateTime"); - assert!(rt.needs_import.as_ref().unwrap().contains("time::PrimitiveDateTime")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("time::PrimitiveDateTime")); } #[test] fn test_datetime_time_crate() { let rt = map_type("DATETIME", TimeCrate::Time); assert_eq!(rt.path, "PrimitiveDateTime"); - assert!(rt.needs_import.as_ref().unwrap().contains("time::PrimitiveDateTime")); + assert!(rt + .needs_import + .as_ref() + .unwrap() + .contains("time::PrimitiveDateTime")); } #[test] diff --git a/crates/sqlx_gen/src/writer.rs b/crates/sqlx_gen/src/writer.rs index ea76e0d..2bba06b 100644 --- a/crates/sqlx_gen/src/writer.rs +++ b/crates/sqlx_gen/src/writer.rs @@ -1,3 +1,4 @@ +use std::io::Write; use std::path::Path; use crate::error::Result; @@ -7,12 +8,32 @@ use crate::codegen::GeneratedFile; const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit."; const INNER_ATTR: &str = "#![allow(unused_attributes)]"; +/// Write `content` to `path` atomically: stream to a sibling temp file then rename. +/// Avoids leaving partially-written files on Ctrl-C or disk-full errors. 
+pub(crate) fn write_atomic(path: &Path, content: &[u8]) -> Result<()> { + let parent = path.parent().ok_or_else(|| { + crate::error::Error::Config(format!( + "Cannot determine parent directory of {}", + path.display() + )) + })?; + let mut tmp = tempfile::NamedTempFile::new_in(parent)?; + tmp.write_all(content)?; + tmp.flush()?; + tmp.persist(path).map_err(|e| e.error)?; + Ok(()) +} + pub fn write_files( files: &[GeneratedFile], output_dir: &Path, single_file: bool, dry_run: bool, ) -> Result<()> { + for f in files { + validate_safe_filename(&f.filename)?; + } + if dry_run { for f in files { println!("{}", build_file_content(f)); @@ -32,6 +53,28 @@ pub fn write_files( Ok(()) } +/// Reject filenames that could escape the output directory (`..`, path +/// separators, absolute paths) or that aren't `.rs` files. Defends against +/// malicious DB metadata in the rare case introspected table names flow into +/// the file name. +fn validate_safe_filename(filename: &str) -> Result<()> { + let p = Path::new(filename); + if filename.is_empty() + || p.components().count() != 1 + || p.is_absolute() + || filename.contains("..") + || filename.contains('/') + || filename.contains('\\') + || !filename.ends_with(".rs") + { + return Err(crate::error::Error::Config(format!( + "Refusing to write generated file with unsafe name: {:?}", + filename + ))); + } + Ok(()) +} + fn build_file_content(f: &GeneratedFile) -> String { let mut content = String::new(); content.push_str(COMMENT); @@ -58,7 +101,7 @@ fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> { } let path = output_dir.join("models.rs"); - std::fs::write(&path, &content)?; + write_atomic(&path, content.as_bytes())?; log::info!("Wrote {}", path.display()); Ok(()) @@ -70,7 +113,7 @@ fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> { for f in files { let content = build_file_content(f); let path = output_dir.join(&f.filename); - std::fs::write(&path, &content)?; + 
write_atomic(&path, content.as_bytes())?; log::info!("Wrote {}", path.display()); let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename); @@ -84,7 +127,7 @@ fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> { } let mod_path = output_dir.join("mod.rs"); - std::fs::write(&mod_path, &mod_content)?; + write_atomic(&mod_path, mod_content.as_bytes())?; log::info!("Wrote {}", mod_path.display()); Ok(()) @@ -107,7 +150,11 @@ mod tests { #[test] fn test_build_content_with_origin() { - let f = make_file("users.rs", "pub struct Users {}", Some("Table: public.users")); + let f = make_file( + "users.rs", + "pub struct Users {}", + Some("Table: public.users"), + ); let content = build_file_content(&f); assert!(content.contains(COMMENT)); assert!(content.contains(INNER_ATTR)); @@ -188,8 +235,16 @@ mod tests { #[test] fn test_multi_creates_files_and_mod() { let files = vec![ - make_file("users.rs", "pub struct Users {}", Some("Table: public.users")), - make_file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")), + make_file( + "users.rs", + "pub struct Users {}", + Some("Table: public.users"), + ), + make_file( + "posts.rs", + "pub struct Posts {}", + Some("Table: public.posts"), + ), ]; let dir = tempfile::tempdir().unwrap(); write_files(&files, dir.path(), false, false).unwrap(); @@ -259,7 +314,11 @@ mod tests { #[test] fn test_single_creates_models_rs() { - let files = vec![make_file("users.rs", "pub struct Users {}", Some("Table: public.users"))]; + let files = vec![make_file( + "users.rs", + "pub struct Users {}", + Some("Table: public.users"), + )]; let dir = tempfile::tempdir().unwrap(); write_files(&files, dir.path(), true, false).unwrap(); assert!(dir.path().join("models.rs").exists()); @@ -308,4 +367,80 @@ mod tests { let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap(); assert!(!content.contains("// ---")); } + + // ========== write_atomic ========== + + #[test] + fn 
test_atomic_creates_file_with_content() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("out.rs"); + write_atomic(&path, b"hello").unwrap(); + assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello"); + } + + #[test] + fn test_atomic_overwrites_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("out.rs"); + std::fs::write(&path, "old").unwrap(); + write_atomic(&path, b"new").unwrap(); + assert_eq!(std::fs::read_to_string(&path).unwrap(), "new"); + } + + #[test] + fn test_atomic_leaves_no_temp_artifacts_on_success() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("out.rs"); + write_atomic(&path, b"x").unwrap(); + let entries: Vec<_> = std::fs::read_dir(dir.path()) + .unwrap() + .map(|e| e.unwrap().file_name()) + .collect(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].to_string_lossy(), "out.rs"); + } + + // ========== validate_safe_filename ========== + + #[test] + fn test_rejects_dot_dot_in_filename() { + let files = vec![make_file("../escape.rs", "code", None)]; + let dir = tempfile::tempdir().unwrap(); + assert!(write_files(&files, dir.path(), false, false).is_err()); + } + + #[test] + fn test_rejects_absolute_path_filename() { + let files = vec![make_file("/etc/passwd", "code", None)]; + let dir = tempfile::tempdir().unwrap(); + assert!(write_files(&files, dir.path(), false, false).is_err()); + } + + #[test] + fn test_rejects_path_separator_in_filename() { + let files = vec![make_file("sub/dir/file.rs", "code", None)]; + let dir = tempfile::tempdir().unwrap(); + assert!(write_files(&files, dir.path(), false, false).is_err()); + } + + #[test] + fn test_rejects_non_rs_extension() { + let files = vec![make_file("evil.sh", "code", None)]; + let dir = tempfile::tempdir().unwrap(); + assert!(write_files(&files, dir.path(), false, false).is_err()); + } + + #[test] + fn test_rejects_empty_filename() { + let files = vec![make_file("", "code", None)]; + let dir = 
tempfile::tempdir().unwrap(); + assert!(write_files(&files, dir.path(), false, false).is_err()); + } + + #[test] + fn test_accepts_normal_rs_filename() { + let files = vec![make_file("users.rs", "code", None)]; + let dir = tempfile::tempdir().unwrap(); + assert!(write_files(&files, dir.path(), false, false).is_ok()); + } } diff --git a/crates/sqlx_gen/tests/compile_check.rs b/crates/sqlx_gen/tests/compile_check.rs new file mode 100644 index 0000000..7f86c4a --- /dev/null +++ b/crates/sqlx_gen/tests/compile_check.rs @@ -0,0 +1,288 @@ +//! Validate that codegen output is syntactically and semantically loadable. +//! +//! The fast path runs on every CI build: each generated file is parsed with +//! `syn::parse_file` (proves the output is valid Rust at the AST level). +//! +//! The deep path runs only when `SQLX_GEN_COMPILE_CHECK=1` is set in the env: +//! it scaffolds a temporary downstream crate, writes the generated files into +//! `src/lib.rs`, and runs `cargo check` against the upstream `sqlx-gen` crate +//! plus `sqlx`. That confirms the emitted attributes, derives, and imports are +//! genuinely accepted by `sqlx::FromRow`, `sqlx::Type`, and friends — not just +//! shaped like valid Rust. 
+ +use std::collections::HashMap; +use std::path::Path; + +use sqlx_gen::cli::{DatabaseKind, DomainStyle, TimeCrate}; +use sqlx_gen::codegen::{generate_with_domain_style, GeneratedFile}; +use sqlx_gen::introspect::{ + ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo, +}; + +fn rich_schema() -> SchemaInfo { + SchemaInfo { + tables: vec![TableInfo { + schema_name: "public".to_string(), + name: "users".to_string(), + columns: vec![ + column("id", "int4", false, true, None), + column("email", "text", false, false, None), + column("name", "text", true, false, None), + column("status", "status", false, false, None), + column("metadata", "jsonb", true, false, None), + ], + }], + views: vec![TableInfo { + schema_name: "public".to_string(), + name: "active_users".to_string(), + columns: vec![ + column("id", "int4", false, false, None), + column("email", "text", false, false, None), + ], + }], + enums: vec![EnumInfo { + schema_name: "public".to_string(), + name: "status".to_string(), + variants: vec!["active".to_string(), "inactive".to_string()], + default_variant: Some("active".to_string()), + }], + composite_types: vec![CompositeTypeInfo { + schema_name: "public".to_string(), + name: "address".to_string(), + fields: vec![ + column("street", "text", false, false, None), + column("city", "text", false, false, None), + ], + }], + domains: vec![DomainInfo { + schema_name: "public".to_string(), + name: "email".to_string(), + base_type: "text".to_string(), + }], + } +} + +fn column(name: &str, udt: &str, nullable: bool, pk: bool, default: Option<&str>) -> ColumnInfo { + ColumnInfo { + name: name.to_string(), + data_type: udt.to_string(), + udt_name: udt.to_string(), + udt_schema: None, + is_nullable: nullable, + is_primary_key: pk, + ordinal_position: 0, + schema_name: "public".to_string(), + column_default: default.map(|s| s.to_string()), + } +} + +fn parse_each_file(files: &[GeneratedFile]) { + for f in files { + 
syn::parse_file(&f.code).unwrap_or_else(|e| { + panic!( + "generated file '{}' is not syntactically valid Rust: {}\n--- BEGIN ---\n{}\n--- END ---", + f.filename, e, f.code + ) + }); + } +} + +#[test] +fn generated_postgres_files_parse() { + let files = generate_with_domain_style( + &rich_schema(), + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + DomainStyle::Alias, + ) + .expect("codegen"); + parse_each_file(&files); +} + +#[test] +fn generated_mysql_files_parse() { + let schema = SchemaInfo { + // MySQL has no enums/composites/domains in our model, just tables. + tables: rich_schema() + .tables + .into_iter() + .map(|mut t| { + t.columns + .retain(|c| c.udt_name != "status" && c.udt_name != "jsonb"); + t + }) + .collect(), + views: vec![], + enums: vec![], + composite_types: vec![], + domains: vec![], + }; + let files = generate_with_domain_style( + &schema, + DatabaseKind::Mysql, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + DomainStyle::Alias, + ) + .expect("codegen"); + parse_each_file(&files); +} + +#[test] +fn generated_sqlite_files_parse() { + let schema = SchemaInfo { + tables: vec![TableInfo { + schema_name: "main".to_string(), + name: "users".to_string(), + columns: vec![ + column("id", "INTEGER", false, true, None), + column("name", "TEXT", false, false, None), + ], + }], + views: vec![], + enums: vec![], + composite_types: vec![], + domains: vec![], + }; + let files = generate_with_domain_style( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + DomainStyle::Alias, + ) + .expect("codegen"); + parse_each_file(&files); +} + +#[test] +fn generated_postgres_files_parse_with_newtype_domain() { + let files = generate_with_domain_style( + &rich_schema(), + DatabaseKind::Postgres, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + DomainStyle::Newtype, + ) + .expect("codegen"); + parse_each_file(&files); + // Ensure the newtype actually emitted a tuple struct (not a 
type alias). + let types_rs = files + .iter() + .find(|f| f.filename == "types.rs") + .expect("types.rs file should be emitted"); + assert!( + types_rs.code.contains("pub struct Email(pub String)"), + "newtype domain not found in:\n{}", + types_rs.code + ); +} + +/// Deep check: actually run `cargo check` against a real downstream crate. +/// Skipped unless `SQLX_GEN_COMPILE_CHECK=1` is set, because it pulls in +/// the full sqlx dependency tree (~1 min on a cold cargo cache). +#[test] +fn generated_files_pass_cargo_check_in_downstream_crate() { + if std::env::var("SQLX_GEN_COMPILE_CHECK").as_deref() != Ok("1") { + eprintln!("skipped (set SQLX_GEN_COMPILE_CHECK=1 to enable)"); + return; + } + + let files = generate_with_domain_style( + &rich_schema(), + DatabaseKind::Postgres, + &[], + &HashMap::new(), + true, // single_file = true so we emit one models.rs + TimeCrate::Chrono, + DomainStyle::Alias, + ) + .expect("codegen"); + + let dir = tempfile::tempdir().expect("temp dir"); + let project_root = workspace_root(); + let sqlx_gen_path = project_root.join("crates/sqlx_gen"); + + std::fs::write( + dir.path().join("Cargo.toml"), + format!( + r#" +[package] +name = "sqlx_gen_compile_check" +version = "0.0.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[dependencies] +sqlx = {{ version = "0.8", default-features = false, features = [ + "runtime-tokio", "tls-rustls-ring", "postgres", "chrono", "uuid", "json", +] }} +sqlx_gen = {{ path = "{}", default-features = false }} +serde = {{ version = "1", features = ["derive"] }} +chrono = "0.4" +uuid = "1" +serde_json = "1" +rust_decimal = "1" +ipnetwork = "0.20" +"#, + sqlx_gen_path.display(), + ), + ) + .unwrap(); + std::fs::create_dir_all(dir.path().join("src")).unwrap(); + // Concatenate all generated files into a single lib.rs. 
+ let mut lib = String::new(); + lib.push_str("#![allow(unused_imports, dead_code, unused_attributes)]\n\n"); + for f in &files { + lib.push_str(&f.code); + lib.push_str("\n\n"); + } + std::fs::write(dir.path().join("src/lib.rs"), &lib).unwrap(); + + let status = std::process::Command::new("cargo") + .arg("check") + .arg("--offline") + .current_dir(dir.path()) + .status() + .or_else(|_| { + std::process::Command::new("cargo") + .arg("check") + .current_dir(dir.path()) + .status() + }) + .expect("invoke cargo check"); + + assert!( + status.success(), + "generated code did not pass `cargo check` in a downstream crate" + ); +} + +fn workspace_root() -> std::path::PathBuf { + let mut p = Path::new(env!("CARGO_MANIFEST_DIR")).to_path_buf(); + // CARGO_MANIFEST_DIR is the sqlx_gen crate; walk up until we hit the workspace Cargo.toml. + while !p.join("Cargo.toml").exists() + || !std::fs::read_to_string(p.join("Cargo.toml")) + .map(|c| c.contains("[workspace]")) + .unwrap_or(false) + { + if !p.pop() { + panic!( + "could not locate workspace root from {}", + env!("CARGO_MANIFEST_DIR") + ); + } + } + p +} diff --git a/crates/sqlx_gen/tests/e2e_sqlite.rs b/crates/sqlx_gen/tests/e2e_sqlite.rs index 8410513..7d9ec72 100644 --- a/crates/sqlx_gen/tests/e2e_sqlite.rs +++ b/crates/sqlx_gen/tests/e2e_sqlite.rs @@ -16,9 +16,21 @@ async fn exec(pool: &SqlitePool, sql: &str) { #[tokio::test] async fn test_simple_table_generates_struct() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("pub 
struct")); } @@ -27,7 +39,15 @@ async fn test_struct_name_pascal_case() { let pool = setup_pool().await; exec(&pool, "CREATE TABLE user_profiles (id INTEGER NOT NULL)").await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("pub struct UserProfiles")); } @@ -36,7 +56,15 @@ async fn test_integer_mapped_to_i64() { let pool = setup_pool().await; exec(&pool, "CREATE TABLE t (id INTEGER NOT NULL)").await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("i64")); } @@ -45,7 +73,15 @@ async fn test_nullable_column_option() { let pool = setup_pool().await; exec(&pool, "CREATE TABLE t (name TEXT)").await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("Option<")); } @@ -55,7 +91,15 @@ async fn test_multiple_tables_multiple_files() { exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await; exec(&pool, "CREATE TABLE posts (id INTEGER NOT NULL)").await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, 
+ TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 2); } @@ -64,18 +108,42 @@ async fn test_filenames_correct() { let pool = setup_pool().await; exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files[0].filename, "users.rs"); } #[tokio::test] async fn test_generated_code_parseable() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; let schema = introspect(&pool, false).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); for f in &files { - assert!(syn::parse_file(&f.code).is_ok(), "Failed to parse {}", f.filename); + assert!( + syn::parse_file(&f.code).is_ok(), + "Failed to parse {}", + f.filename + ); } } @@ -85,7 +153,15 @@ async fn test_extra_derives_propagated() { exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await; let schema = introspect(&pool, false).await.unwrap(); let derives = vec!["Serialize".to_string()]; - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &derives, &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &derives, + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert!(files[0].code.contains("Serialize")); } @@ -94,11 +170,30 @@ async fn test_extra_derives_propagated() { #[tokio::test] async fn test_view_generates_struct() { let 
pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; - exec(&pool, "CREATE VIEW active_users AS SELECT id, name FROM users").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; + exec( + &pool, + "CREATE VIEW active_users AS SELECT id, name FROM users", + ) + .await; let schema = introspect(&pool, true).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); - let view_file = files.iter().find(|f| f.filename == "active_users.rs").unwrap(); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); + let view_file = files + .iter() + .find(|f| f.filename == "active_users.rs") + .unwrap(); assert!(view_file.code.contains("pub struct ActiveUsers")); } @@ -108,7 +203,15 @@ async fn test_view_origin_contains_view() { exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await; exec(&pool, "CREATE VIEW v AS SELECT id FROM users").await; let schema = introspect(&pool, true).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); let view_file = files.iter().find(|f| f.filename == "v.rs").unwrap(); assert_eq!(view_file.origin, None); } @@ -116,12 +219,28 @@ async fn test_view_origin_contains_view() { #[tokio::test] async fn test_view_code_parseable() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; exec(&pool, "CREATE VIEW user_view AS SELECT id, name FROM users").await; let schema = introspect(&pool, true).await.unwrap(); - let files = 
codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); for f in &files { - assert!(syn::parse_file(&f.code).is_ok(), "Failed to parse {}", f.filename); + assert!( + syn::parse_file(&f.code).is_ok(), + "Failed to parse {}", + f.filename + ); } } @@ -129,10 +248,25 @@ async fn test_view_code_parseable() { async fn test_view_pascal_case_name() { let pool = setup_pool().await; exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await; - exec(&pool, "CREATE VIEW all_active_users AS SELECT id FROM users").await; + exec( + &pool, + "CREATE VIEW all_active_users AS SELECT id FROM users", + ) + .await; let schema = introspect(&pool, true).await.unwrap(); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); - let view_file = files.iter().find(|f| f.filename == "all_active_users.rs").unwrap(); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); + let view_file = files + .iter() + .find(|f| f.filename == "all_active_users.rs") + .unwrap(); assert!(view_file.code.contains("pub struct AllActiveUsers")); } @@ -146,7 +280,15 @@ async fn test_exclude_table() { let mut schema = introspect(&pool, false).await.unwrap(); let exclude = ["_migrations".to_string()]; schema.tables.retain(|t| !exclude.contains(&t.name)); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); assert_eq!(files.len(), 1); assert_eq!(files[0].filename, "users.rs"); } @@ -188,8 +330,19 @@ async fn test_exclude_view() { let mut schema = introspect(&pool, true).await.unwrap(); let exclude = 
["v1".to_string()]; schema.views.retain(|v| !exclude.contains(&v.name)); - let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false, TimeCrate::Chrono); - let view_files: Vec<_> = files.iter().filter(|f| f.code.contains("kind = \"view\"")).collect(); + let files = codegen::generate( + &schema, + DatabaseKind::Sqlite, + &[], + &HashMap::new(), + false, + TimeCrate::Chrono, + ) + .unwrap(); + let view_files: Vec<_> = files + .iter() + .filter(|f| f.code.contains("kind = \"view\"")) + .collect(); assert_eq!(view_files.len(), 1); assert_eq!(view_files[0].filename, "v2.rs"); } diff --git a/crates/sqlx_gen/tests/introspect_sqlite.rs b/crates/sqlx_gen/tests/introspect_sqlite.rs index 092b982..0d7f712 100644 --- a/crates/sqlx_gen/tests/introspect_sqlite.rs +++ b/crates/sqlx_gen/tests/introspect_sqlite.rs @@ -44,7 +44,11 @@ async fn test_empty_db_no_domains() { #[tokio::test] async fn test_one_table_two_columns() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; let schema = introspect(&pool, false).await.unwrap(); assert_eq!(schema.tables.len(), 1); assert_eq!(schema.tables[0].columns.len(), 2); @@ -69,9 +73,17 @@ async fn test_schema_name_main() { #[tokio::test] async fn test_column_names_and_order() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL, email TEXT)").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL, email TEXT)", + ) + .await; let schema = introspect(&pool, false).await.unwrap(); - let cols: Vec<&str> = schema.tables[0].columns.iter().map(|c| c.name.as_str()).collect(); + let cols: Vec<&str> = schema.tables[0] + .columns + .iter() + .map(|c| c.name.as_str()) + .collect(); assert_eq!(cols, vec!["id", "name", "email"]); } @@ -120,8 +132,16 @@ async fn 
test_multiple_tables_sorted() { #[tokio::test] async fn test_view_introspected_with_flag() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; - exec(&pool, "CREATE VIEW active_users AS SELECT id, name FROM users").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; + exec( + &pool, + "CREATE VIEW active_users AS SELECT id, name FROM users", + ) + .await; let schema = introspect(&pool, true).await.unwrap(); assert_eq!(schema.views.len(), 1); assert_eq!(schema.views[0].name, "active_users"); @@ -130,10 +150,22 @@ async fn test_view_introspected_with_flag() { #[tokio::test] async fn test_view_columns_correct() { let pool = setup_pool().await; - exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)").await; - exec(&pool, "CREATE VIEW user_names AS SELECT id, name FROM users").await; + exec( + &pool, + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT NOT NULL)", + ) + .await; + exec( + &pool, + "CREATE VIEW user_names AS SELECT id, name FROM users", + ) + .await; let schema = introspect(&pool, true).await.unwrap(); - let cols: Vec<&str> = schema.views[0].columns.iter().map(|c| c.name.as_str()).collect(); + let cols: Vec<&str> = schema.views[0] + .columns + .iter() + .map(|c| c.name.as_str()) + .collect(); assert_eq!(cols, vec!["id", "name"]); } diff --git a/crates/sqlx_gen_macros/Cargo.toml b/crates/sqlx_gen_macros/Cargo.toml index c1d7015..850b40d 100644 --- a/crates/sqlx_gen_macros/Cargo.toml +++ b/crates/sqlx_gen_macros/Cargo.toml @@ -1,12 +1,13 @@ [package] name = "sqlx-gen-macros" -version = "0.5.5" -edition = "2021" -description = "No-op attribute macros for sqlx-gen generated code" -license = "MIT" -repository = "https://github.com/LeadcodeDev/sqlx-gen" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true keywords = ["sqlx", "codegen", 
"macros"] -categories = ["database", "development-tools"] +categories.workspace = true +description = "No-op attribute macros for sqlx-gen generated code" [lib] proc-macro = true