From 07430b06101dc258e49086eafc7b28690f59ed21 Mon Sep 17 00:00:00 2001 From: coord_e Date: Thu, 30 Apr 2026 11:48:35 +0900 Subject: [PATCH 1/3] Support immutable references in formula_fn --- src/analyze/annot_fn.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 4e7c396..8be6dc5 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -333,10 +333,13 @@ impl<'tcx> AnnotFnTranslator<'tcx> { } rustc_hir::UnOp::Deref => { let operand_ty = self.expr_ty(operand); + let term = self.to_term(operand); + if matches!(operand_ty.kind(), mir_ty::TyKind::Ref(_, _, mir_ty::Mutability::Not)) { + return FormulaOrTerm::Term(term.box_current()); + } let adt = operand_ty .ty_adt_def() .expect("deref operand must be a model type"); - let term = self.to_term(operand); if Some(adt.did()) == self.def_ids.mut_model() { FormulaOrTerm::Term(term.mut_current()) } else if Some(adt.did()) == self.def_ids.box_model() { @@ -349,6 +352,10 @@ impl<'tcx> AnnotFnTranslator<'tcx> { } } }, + ExprKind::AddrOf(rustc_hir::BorrowKind::Ref, rustc_hir::Mutability::Not, operand) => { + let operand = self.to_term(operand); + FormulaOrTerm::Term(operand.boxed()) + } ExprKind::Lit(lit) => match lit.node { rustc_ast::LitKind::Int(i, _) => { let n = i64::try_from(i.get()) From 404350bf384df43ed7fadee072dee8d5c46c57d8 Mon Sep 17 00:00:00 2001 From: coord_e Date: Thu, 30 Apr 2026 11:48:57 +0900 Subject: [PATCH 2/3] Support array operations in formula_fn --- src/analyze/annot.rs | 8 ++++++++ src/analyze/annot_fn.rs | 41 ++++++++++++++++++++++++++++++---------- src/analyze/did_cache.rs | 8 ++++++++ std.rs | 31 +++++++++++++++++++++++++++++- 4 files changed, 77 insertions(+), 11 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 3163803..b94a1ef 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -125,6 +125,14 @@ pub fn box_model_new_path() -> [Symbol; 3] { ] } +pub fn array_model_store_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("array_store"), + ] +} + pub fn exists_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 8be6dc5..637ac6d 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -393,6 +393,23 @@ impl<'tcx> AnnotFnTranslator<'tcx> { let term = self.to_term(expr); FormulaOrTerm::Term(term.tuple_proj(index)) } + ExprKind::Index(array, index, _) => { + let array_term = self.to_term(array); + let index_term = self.to_term(index); + FormulaOrTerm::Term(array_term.select(index_term)) + } + ExprKind::MethodCall(method, receiver, args, _) => { + if let Some(def_id) = self.typeck.type_dependent_def_id(hir.hir_id) { + if Some(def_id) == self.def_ids.array_model_store() { + assert_eq!(args.len(), 2, "array_store takes exactly 2 arguments"); + let array_term = self.to_term(receiver); + let index_term = self.to_term(&args[0]); + let value_term = self.to_term(&args[1]); + return FormulaOrTerm::Term(array_term.store(index_term, value_term)); + } + } + unimplemented!("unsupported method call in formula: {:?}", method) + } ExprKind::Call(func_expr, args) => { if let ExprKind::Path(qpath) = &func_expr.kind { let res = self.typeck.qpath_res(qpath, func_expr.hir_id); @@ -441,16 +458,20 @@ impl<'tcx> AnnotFnTranslator<'tcx> { let t = self.to_term(&args[0]); return FormulaOrTerm::Term(chc::Term::box_(t)); } - if matches!( - def_kind, - rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _) - ) { - let field_terms = args.iter().map(|arg| self.to_term(arg)).collect(); - return FormulaOrTerm::Term(self.variant_ctor_term( - def_id, - self.expr_ty(hir), - field_terms, - )); + if let rustc_hir::def::DefKind::Ctor(ctor_of, _) = def_kind { + let terms = args.iter().map(|e| self.to_term(e)).collect(); + match ctor_of { + rustc_hir::def::CtorOf::Variant => { + return FormulaOrTerm::Term(self.variant_ctor_term( + def_id, + self.expr_ty(hir), + terms, + )); + } + rustc_hir::def::CtorOf::Struct => { + return FormulaOrTerm::Term(chc::Term::tuple(terms)); + } + } } } } diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index 99b6e4c..bb6fb90 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -22,6 +22,7 @@ struct DefIds { mut_model_new: OnceCell>, box_model_new: OnceCell>, + array_model_store: OnceCell>, exists: OnceCell>, } @@ -163,6 +164,13 @@ impl<'tcx> DefIdCache<'tcx> { .get_or_init(|| self.annotated_def(&crate::analyze::annot::box_model_new_path())) } + pub fn array_model_store(&self) -> Option { + *self + .def_ids + .array_model_store + .get_or_init(|| self.annotated_def(&crate::analyze::annot::array_model_store_path())) + } + pub fn exists(&self) -> Option { *self .def_ids diff --git a/std.rs b/std.rs index 3318656..6d6718c 100644 --- a/std.rs +++ b/std.rs @@ -137,10 +137,35 @@ mod thrust_models { } } + impl std::ops::Index for Array { + type Output = T; + + #[thrust::ignored] + fn index(&self, _index: I) -> &Self::Output { + unimplemented!() + } + } + + impl Array { + #[allow(dead_code)] + #[thrust::def::array_store] + #[thrust::ignored] + pub fn store(&self, _index: I, _value: T) -> Self { + unimplemented!() + } + } + #[thrust::def::closure_model] pub struct Closure(PhantomData); - pub struct Vec(pub Array, pub usize); + pub struct Vec(pub Array, pub Int); + + impl PartialEq for Vec where U: super::Model { + #[thrust::ignored] + fn eq(&self, _other: &U) -> bool { + unimplemented!() + } + } } impl Model for model::Int { @@ -219,6 +244,10 @@ mod thrust_models { type Ty = model::Box; } + impl Model for model::Array { + type Ty = model::Array; + } + impl Model for Vec where T: Model { type Ty = model::Vec<::Ty>; } From 3d7a24a778e6032c1df24722185818ce6342237b Mon Sep 17 00:00:00 2001 From: coord_e Date: Thu, 30 Apr 2026 11:49:23 +0900 Subject: [PATCH 3/3] fixup! Support immutable references in formula_fn --- src/analyze/annot_fn.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 637ac6d..9cc54d4 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -334,7 +334,10 @@ impl<'tcx> AnnotFnTranslator<'tcx> { rustc_hir::UnOp::Deref => { let operand_ty = self.expr_ty(operand); let term = self.to_term(operand); - if matches!(operand_ty.kind(), mir_ty::TyKind::Ref(_, _, mir_ty::Mutability::Not)) { + if matches!( + operand_ty.kind(), + mir_ty::TyKind::Ref(_, _, mir_ty::Mutability::Not) + ) { return FormulaOrTerm::Term(term.box_current()); } let adt = operand_ty