diff --git a/Cargo.lock b/Cargo.lock index ba30fe2f4d..b4e24f95e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -821,9 +821,9 @@ checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "linux-raw-sys" -version = "0.4.10" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" +checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" [[package]] name = "lock_api" @@ -1177,9 +1177,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.19" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" +checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" dependencies = [ "bitflags 2.4.0", "errno", diff --git a/crates/flux-desugar/src/desugar.rs b/crates/flux-desugar/src/desugar.rs index 36bd741ca5..3f158e1847 100644 --- a/crates/flux-desugar/src/desugar.rs +++ b/crates/flux-desugar/src/desugar.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, iter, slice}; +use std::{borrow::Borrow, iter}; use flux_common::{bug, index::IndexGen, iter::IterExt, span_bug}; use flux_errors::FluxSession; @@ -6,9 +6,11 @@ use flux_middle::{ fhir::{self, lift::LiftCtxt, ExprKind, FhirId, FluxOwnerId, Res}, global_env::GlobalEnv, intern::List, + rty::{self}, }; use flux_syntax::surface; -use hir::{def::DefKind, ItemKind}; +use hir::{def::DefKind, ItemKind, PrimTy}; +use itertools::Itertools; use rustc_data_structures::{ fx::{FxIndexMap, IndexEntry}, unord::UnordMap, @@ -17,6 +19,7 @@ use rustc_errors::{ErrorGuaranteed, IntoDiagnostic}; use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hir as hir; use rustc_hir::OwnerId; +use rustc_middle::ty::Generics; use rustc_span::{ def_id::LocalDefId, sym::{self}, @@ -32,7 +35,11 @@ pub fn desugar_qualifier( genv: &GlobalEnv, qualifier: &surface::Qualifier, ) -> Result { - let mut binders = Binders::from_params(genv, &qualifier.args)?; + let sort_params = &[]; + let sort_resolver = + SortResolver::with_sort_params(genv.sess, genv.map().sort_decls(), sort_params); + + let mut binders = Binders::from_params(genv, &sort_resolver, &qualifier.args)?; let index_gen = IndexGen::new(); let cx = ExprCtxt::new(genv, FluxOwnerId::Flux(qualifier.name.name), &index_gen); let expr = cx.desugar_expr(&binders, &qualifier.expr); @@ -47,14 +54,18 @@ pub fn desugar_qualifier( pub fn desugar_defn(genv: &GlobalEnv, defn: surface::FuncDef) -> Result> { if let Some(body) = defn.body { - let mut binders = Binders::from_params(genv, &defn.args)?; + let sort_params = defn.sort_vars.iter().map(|ident| ident.name).collect_vec(); + let sort_resolver = + SortResolver::with_sort_params(genv.sess, genv.map().sort_decls(), &sort_params); + let mut binders = Binders::from_params(genv, &sort_resolver, &defn.args)?; let local_id_gen = IndexGen::new(); let cx = ExprCtxt::new(genv, FluxOwnerId::Flux(defn.name.name), &local_id_gen); let expr = cx.desugar_expr(&binders, &body)?; let name = defn.name.name; - let sort = resolve_sort(genv.sess, genv.map().sort_decls(), &defn.output)?; + let params = defn.sort_vars.len(); + let sort = sort_resolver.resolve_sort(&defn.output)?; let args = binders.pop_layer().into_params(&cx); - Ok(Some(fhir::Defn { name, args, sort, expr })) + Ok(Some(fhir::Defn { name, params, args, sort, expr })) } else { Ok(None) } @@ -65,21 +76,76 @@ pub fn func_def_to_func_decl( sort_decls: &fhir::SortDecls, defn: &surface::FuncDef, ) -> Result { + let params = defn.sort_vars.len(); + let sort_vars = defn.sort_vars.iter().map(|ident| ident.name).collect_vec(); + let sr = SortResolver::with_sort_params(sess, sort_decls, &sort_vars); let inputs: Vec = defn .args .iter() - .map(|arg| resolve_sort(sess, sort_decls, &arg.sort)) + .map(|arg| sr.resolve_sort(&arg.sort)) .try_collect_exhaust()?; - let output = resolve_sort(sess, sort_decls, &defn.output)?; - let sort = fhir::PolyFuncSort::new(0, inputs, output); + let output = sr.resolve_sort(&defn.output)?; + let sort = fhir::PolyFuncSort::new(params, inputs, output); let kind = if defn.body.is_some() { fhir::FuncKind::Def } else { fhir::FuncKind::Uif }; Ok(fhir::FuncDecl { name: defn.name.name, sort, kind }) } +fn gather_base_sort_vars( + generics: &FxHashSet, + base_sort: &surface::BaseSort, + sort_vars: &mut FxHashSet, +) { + match base_sort { + surface::BaseSort::Ident(x) => { + if generics.contains(&x.name) { + sort_vars.insert(x.name); + } + } + surface::BaseSort::BitVec(_) => {} + surface::BaseSort::App(_, base_sorts) => { + for base_sort in base_sorts { + gather_base_sort_vars(generics, base_sort, sort_vars); + } + } + } +} +fn gather_sort_vars( + generics: &FxHashSet, + sort: &surface::Sort, + sort_vars: &mut FxHashSet, +) { + match sort { + surface::Sort::Base(base_sort) => gather_base_sort_vars(generics, base_sort, sort_vars), + surface::Sort::Func { inputs, output } => { + for base_sort in inputs { + gather_base_sort_vars(generics, base_sort, sort_vars); + } + gather_base_sort_vars(generics, output, sort_vars); + } + surface::Sort::Infer => {} + } +} + +fn gather_refined_by_sort_vars( + generics: &rustc_middle::ty::Generics, + refined_by: &surface::RefinedBy, +) -> Vec { + let generics_syms: FxHashSet = generics.params.iter().map(|param| param.name).collect(); + let mut sort_idents = FxHashSet::default(); + for refine_param in &refined_by.index_params { + gather_sort_vars(&generics_syms, &refine_param.sort, &mut sort_idents); + } + generics + .params + .iter() + .filter_map(|param| if sort_idents.contains(¶m.name) { Some(param.name) } else { None }) + .collect() +} pub fn desugar_refined_by( sess: &FluxSession, sort_decls: &fhir::SortDecls, owner_id: OwnerId, + generics: &rustc_middle::ty::Generics, refined_by: &surface::RefinedBy, ) -> Result { let mut set = FxHashSet::default(); @@ -90,19 +156,36 @@ pub fn desugar_refined_by( Ok(()) } })?; + + let sort_vars = gather_refined_by_sort_vars(generics, refined_by); + let sr = SortResolver::with_sort_params(sess, sort_decls, &sort_vars); + let early_bound_params: Vec<_> = refined_by .early_bound_params .iter() - .map(|param| resolve_sort(sess, sort_decls, ¶m.sort)) + .map(|param| sr.resolve_sort(¶m.sort)) .try_collect_exhaust()?; let index_params: Vec<_> = refined_by .index_params .iter() - .map(|param| Ok((param.name.name, resolve_sort(sess, sort_decls, ¶m.sort)?))) + .map(|param| Ok((param.name.name, sr.resolve_sort(¶m.sort)?))) .try_collect_exhaust()?; - Ok(fhir::RefinedBy::new(owner_id.def_id, early_bound_params, index_params, refined_by.span)) + let generic_idx: FxHashMap = generics + .params + .iter() + .map(|param| (param.name, param.def_id)) + .collect(); + let sort_params = sort_vars.iter().map(|sym| generic_idx[&sym]).collect(); + + Ok(fhir::RefinedBy::new( + owner_id.def_id, + early_bound_params, + index_params, + sort_params, + refined_by.span, + )) } pub(crate) struct DesugarCtxt<'a, 'tcx> { @@ -111,6 +194,7 @@ pub(crate) struct DesugarCtxt<'a, 'tcx> { owner: OwnerId, resolver_output: &'a ResolverOutput, opaque_tys: Option<&'a mut UnordMap>, + sort_resolver: SortResolver<'a>, } /// Keeps track of the surface level identifiers in scope and a mapping between them and a @@ -167,7 +251,17 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { resolver_output: &'a ResolverOutput, opaque_tys: Option<&'a mut UnordMap>, ) -> DesugarCtxt<'a, 'tcx> { - DesugarCtxt { genv, owner, local_id_gen: IndexGen::new(), resolver_output, opaque_tys } + let generics = genv.tcx.generics_of(owner); + let sort_resolver = + SortResolver::with_generics(genv.sess, genv.map().sort_decls(), generics); + DesugarCtxt { + genv, + owner, + local_id_gen: IndexGen::new(), + sort_resolver, + resolver_output, + opaque_tys, + } } fn with_new_owner<'b>(&'b mut self, owner: OwnerId) -> DesugarCtxt<'b, 'tcx> { @@ -184,6 +278,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { pub(crate) fn desugar_generics(&self, generics: &surface::Generics) -> Result { let hir_generics = self.genv.hir().get_generics(self.owner.def_id).unwrap(); + let generics_map: FxHashMap<_, _> = hir_generics .params .iter() @@ -201,6 +296,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { let kind = match ¶m.kind { surface::GenericParamKind::Type => fhir::GenericParamKind::Type { default: None }, surface::GenericParamKind::Base => fhir::GenericParamKind::BaseTy, + surface::GenericParamKind::Spl => fhir::GenericParamKind::SplTy, surface::GenericParamKind::Refine { .. } => { continue; } @@ -272,6 +368,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { binders.push_layer(); binders.insert_params( self.genv, + &self.sort_resolver, struct_def .refined_by .iter() @@ -334,6 +431,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { binders.push_layer(); binders.insert_params( self.genv, + &self.sort_resolver, enum_def .refined_by .iter() @@ -394,7 +492,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { binders: &mut Binders, ) -> Result { binders.push_layer(); - binders.insert_params(self.genv, ty_alias.refined_by.all_params())?; + binders.insert_params(self.genv, &self.sort_resolver, ty_alias.refined_by.all_params())?; let ty = self.desugar_ty(None, &ty_alias.ty, binders)?; @@ -580,6 +678,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { binders: &mut Binders, ) -> Result { let span = ty.span; + let generics = self.genv.tcx.generics_of(self.owner.def_id); let kind = match &ty.kind { surface::TyKind::Base(bty) => { // CODESYNC(type-holes, 3) @@ -626,10 +725,14 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { } surface::TyKind::GeneralExists { params, ty, pred } => { binders.push_layer(); + let sr = SortResolver::with_generics( + self.sess(), + self.genv.map().sort_decls(), + generics, + ); for param in params { let fresh = binders.fresh(); - let sort = - resolve_sort(self.sess(), self.genv.map().sort_decls(), ¶m.sort)?; + let sort = sr.resolve_sort(¶m.sort)?; let binder = Binder::Refined(fresh, sort.clone(), false); binders.insert_binder(self.sess(), param.name, binder)?; } @@ -708,9 +811,9 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { self.ident_into_refine_arg(*ident, binders) .transpose() .unwrap() - } else if let Some(fhir::Sort::Record(def_id)) = self.genv.sort_of_bty(bty) { + } else if let Some(fhir::Sort::Record(def_id, sort_args)) = self.genv.sort_of_bty(bty) { let flds = self.desugar_refine_args(&idxs.indices, binders)?; - Ok(fhir::RefineArg::Record(def_id, flds, idxs.span)) + Ok(fhir::RefineArg::Record(def_id, sort_args, flds, idxs.span)) } else if let [arg] = &idxs.indices[..] { self.desugar_refine_arg(arg, binders) } else { @@ -746,7 +849,7 @@ impl<'a, 'tcx> DesugarCtxt<'a, 'tcx> { } surface::RefineArg::Abs(params, body, span) => { binders.push_layer(); - binders.insert_params(self.genv, params)?; + binders.insert_params(self.genv, &self.sort_resolver, params)?; let body = self.as_expr_ctxt().desugar_expr(binders, body)?; let params = binders.pop_layer().into_params(self); Ok(fhir::RefineArg::Abs(params, body, *span, self.next_fhir_id())) @@ -1064,8 +1167,7 @@ impl DesugarCtxt<'_, '_> { binders: &mut Binders, ) -> Result { self.gather_params_path(&ret.path, TypePos::Other, binders)?; - let res = self.resolver_output.path_res_map[&ret.path.node_id]; - let Some(sort) = self.genv.sort_of_res(res) else { + let Some(sort) = sort_of_surface_path(self.genv, self.resolver_output, &ret.path) else { return Err(self.emit_err(errors::RefinedUnrefinableType::new(ret.path.span))); }; self.gather_params_indices(sort, &ret.indices, TypePos::Other, binders) @@ -1086,16 +1188,15 @@ impl DesugarCtxt<'_, '_> { } fn gather_input_params_fn_sig(&self, fn_sig: &surface::FnSig, binders: &mut Binders) -> Result { + let generics = self.genv.tcx.generics_of(self.owner.def_id); + let sr = + SortResolver::with_generics(self.genv.sess, self.genv.map().sort_decls(), generics); for param in fn_sig.generics.iter().flat_map(|g| &g.params) { let surface::GenericParamKind::Refine { sort } = ¶m.kind else { continue }; binders.insert_binder( self.genv.sess, param.name, - Binder::Refined( - binders.fresh(), - resolve_sort(self.genv.sess, self.genv.map().sort_decls(), sort)?, - false, - ), + Binder::Refined(binders.fresh(), sr.resolve_sort(sort)?, false), )?; } for arg in &fn_sig.args { @@ -1124,12 +1225,12 @@ impl DesugarCtxt<'_, '_> { fn gather_params_fun_arg(&self, arg: &surface::Arg, binders: &mut Binders) -> Result { match arg { surface::Arg::Constr(bind, path, _) => { - let res = self.resolver_output.path_res_map[&path.node_id]; - binders.insert_binder( - self.genv.sess, - *bind, - binders.binder_from_res(self.genv, res), - )?; + let zz = sort_of_surface_path(self.genv, self.resolver_output, path); + let Some(sort) = zz else { + return Err(self.emit_err(errors::RefinedUnrefinableType::new(path.span))); + }; + + binders.insert_binder(self.genv.sess, *bind, binders.binder_from_sort(sort))?; } surface::Arg::StrgRef(loc, ty) => { binders.insert_binder( @@ -1302,91 +1403,114 @@ impl DesugarCtxt<'_, '_> { } } -fn resolve_sort( - sess: &FluxSession, - sort_decls: &fhir::SortDecls, - sort: &surface::Sort, -) -> Result { - match sort { - surface::Sort::Base(sort) => resolve_base_sort(sess, sort_decls, sort), - surface::Sort::Func { inputs, output } => { - Ok(resolve_func_sort(sess, sort_decls, inputs, output)?.into()) - } - surface::Sort::Infer => Ok(fhir::Sort::Wildcard), - } +struct SortResolver<'a> { + sess: &'a FluxSession, + sort_decls: &'a fhir::SortDecls, + generic_params: FxHashMap, + sort_params: FxHashMap, } -fn resolve_func_sort( - sess: &FluxSession, - sort_decls: &fhir::SortDecls, - inputs: &[surface::BaseSort], - output: &surface::BaseSort, -) -> Result { - let inputs: Vec = inputs - .iter() - .map(|sort| resolve_base_sort(sess, sort_decls, sort)) - .try_collect_exhaust()?; - let output = resolve_base_sort(sess, sort_decls, output)?; - Ok(fhir::PolyFuncSort::new(0, inputs, output)) -} +impl<'a> SortResolver<'a> { + pub fn with_sort_params( + sess: &'a FluxSession, + sort_decls: &'a fhir::SortDecls, + sort_params: &[Symbol], + ) -> Self { + let sort_params = sort_params + .iter() + .enumerate() + .map(|(i, v)| (*v, i)) + .collect(); + Self { sess, sort_decls, generic_params: Default::default(), sort_params } + } -fn resolve_base_sort( - sess: &FluxSession, - sort_decls: &fhir::SortDecls, - base: &surface::BaseSort, -) -> Result { - match base { - surface::BaseSort::Ident(ident) => resolve_base_sort_ident(sess, sort_decls, ident), - surface::BaseSort::BitVec(w) => Ok(fhir::Sort::BitVec(*w)), - surface::BaseSort::App(ident, args) => resolve_app_sort(sess, sort_decls, *ident, args), + pub fn with_generics( + sess: &'a FluxSession, + sort_decls: &'a fhir::SortDecls, + generics: &'a Generics, + ) -> Self { + let generic_params = generics.params.iter().map(|p| (p.name, p.def_id)).collect(); + Self { sess, sort_decls, sort_params: Default::default(), generic_params } } -} -fn resolve_sort_ctor(sess: &FluxSession, ident: surface::Ident) -> Result { - if ident.name == SORTS.set { - Ok(fhir::SortCtor::Set) - } else if ident.name == SORTS.map { - Ok(fhir::SortCtor::Map) - } else { - Err(sess.emit_err(errors::UnresolvedSort::new(ident))) + fn resolve_sort(&self, sort: &surface::Sort) -> Result { + match sort { + surface::Sort::Base(sort) => self.resolve_base_sort(sort), + surface::Sort::Func { inputs, output } => { + Ok(self.resolve_func_sort(inputs, output)?.into()) + } + surface::Sort::Infer => Ok(fhir::Sort::Wildcard), + } } -} -fn resolve_app_sort( - sess: &FluxSession, - sort_decls: &fhir::SortDecls, - ident: surface::Ident, - args: &Vec, -) -> Result { - let ctor = resolve_sort_ctor(sess, ident)?; - let arity = ctor.arity(); - if args.len() == arity { - let args = args + fn resolve_func_sort( + &self, + inputs: &[surface::BaseSort], + output: &surface::BaseSort, + ) -> Result { + let inputs: Vec = inputs .iter() - .map(|arg| resolve_base_sort(sess, sort_decls, arg)) + .map(|sort| self.resolve_base_sort(sort)) .try_collect_exhaust()?; - Ok(fhir::Sort::App(ctor, args)) - } else { - Err(sess.emit_err(errors::SortArityMismatch::new(ident.span, arity, args.len()))) + let output = self.resolve_base_sort(output)?; + Ok(fhir::PolyFuncSort::new(0, inputs, output)) } -} -fn resolve_base_sort_ident( - sess: &FluxSession, - sort_decls: &fhir::SortDecls, - ident: &surface::Ident, -) -> Result { - if ident.name == SORTS.int { - Ok(fhir::Sort::Int) - } else if ident.name == sym::bool { - Ok(fhir::Sort::Bool) - } else if ident.name == SORTS.real { - Ok(fhir::Sort::Real) - } else if sort_decls.get(&ident.name).is_some() { - let ctor = fhir::SortCtor::User { name: ident.name, arity: 0 }; - Ok(fhir::Sort::App(ctor, List::empty())) - } else { - Err(sess.emit_err(errors::UnresolvedSort::new(*ident))) + fn resolve_base_sort(&self, base: &surface::BaseSort) -> Result { + match base { + surface::BaseSort::Ident(ident) => self.resolve_base_sort_ident(ident), + surface::BaseSort::BitVec(w) => Ok(fhir::Sort::BitVec(*w)), + surface::BaseSort::App(ident, args) => self.resolve_app_sort(*ident, args), + } + } + + fn resolve_sort_ctor(&self, ident: surface::Ident) -> Result { + if ident.name == SORTS.set { + Ok(fhir::SortCtor::Set) + } else if ident.name == SORTS.map { + Ok(fhir::SortCtor::Map) + } else { + Err(self.sess.emit_err(errors::UnresolvedSort::new(ident))) + } + } + + fn resolve_app_sort( + &self, + ident: surface::Ident, + args: &Vec, + ) -> Result { + let ctor = self.resolve_sort_ctor(ident)?; + let arity = ctor.arity(); + if args.len() == arity { + let args = args + .iter() + .map(|arg| self.resolve_base_sort(arg)) + .try_collect_exhaust()?; + Ok(fhir::Sort::App(ctor, args)) + } else { + Err(self + .sess + .emit_err(errors::SortArityMismatch::new(ident.span, arity, args.len()))) + } + } + + fn resolve_base_sort_ident(&self, ident: &surface::Ident) -> Result { + if ident.name == SORTS.int { + Ok(fhir::Sort::Int) + } else if ident.name == sym::bool { + Ok(fhir::Sort::Bool) + } else if ident.name == SORTS.real { + Ok(fhir::Sort::Real) + } else if let Some(def_id) = self.generic_params.get(&ident.name) { + Ok(fhir::Sort::Param(*def_id)) + } else if let Some(idx) = self.sort_params.get(&ident.name) { + Ok(fhir::Sort::Var(*idx)) + } else if self.sort_decls.get(&ident.name).is_some() { + let ctor = fhir::SortCtor::User { name: ident.name, arity: 0 }; + Ok(fhir::Sort::App(ctor, List::empty())) + } else { + Err(self.sess.emit_err(errors::UnresolvedSort::new(*ident))) + } } } @@ -1397,28 +1521,27 @@ impl Binders { fn from_params<'a>( genv: &GlobalEnv, + sort_resolver: &SortResolver, params: impl IntoIterator, ) -> Result { let mut binders = Self::new(); binders.push_layer(); - binders.insert_params(genv, params)?; + binders.insert_params(genv, sort_resolver, params)?; Ok(binders) } fn insert_params<'a>( &mut self, genv: &GlobalEnv, + sort_resolver: &SortResolver, params: impl IntoIterator, ) -> Result { + // let sr = SortResolver::with_sort_params(genv.sess, genv.map().sort_decls(), sort_params); for param in params { self.insert_binder( genv.sess, param.name, - Binder::Refined( - self.fresh(), - resolve_sort(genv.sess, genv.map().sort_decls(), ¶m.sort)?, - false, - ), + Binder::Refined(self.fresh(), sort_resolver.resolve_sort(¶m.sort)?, false), )?; } Ok(()) @@ -1465,14 +1588,6 @@ impl Binders { Binder::Refined(self.fresh(), sort, true) } - fn binder_from_res(&self, genv: &GlobalEnv, res: fhir::Res) -> Binder { - if let Some(sort) = genv.sort_of_res(res) { - self.binder_from_sort(sort) - } else { - Binder::Unrefined - } - } - fn binder_from_bty( &self, genv: &GlobalEnv, @@ -1581,21 +1696,65 @@ fn index_sort( resolver_output: &ResolverOutput, bty: &surface::BaseTy, ) -> Option { - // CODESYNC(sort-of, 4) sorts should be given consistently + // CODESYNC(sort-of, 3) sorts should be given consistently match &bty.kind { - surface::BaseTyKind::Path(path) => { - let res = resolver_output.path_res_map[&path.node_id]; - genv.sort_of_res(res) - } + surface::BaseTyKind::Path(path) => sort_of_surface_path(genv, resolver_output, path), surface::BaseTyKind::Slice(_) => Some(fhir::Sort::Int), } } -fn as_tuple<'a>(genv: &'a GlobalEnv, sort: &'a fhir::Sort) -> &'a [fhir::Sort] { - if let fhir::Sort::Record(def_id) = sort { - genv.index_sorts_of(*def_id) +fn sort_of_surface_path( + genv: &GlobalEnv, + resolver_output: &ResolverOutput, + path: &surface::Path, +) -> Option { + // CODESYNC(sort-of-path, 2) sorts should be given consistently + let res = resolver_output.path_res_map[&path.node_id]; + + match res { + fhir::Res::PrimTy(PrimTy::Int(_) | PrimTy::Uint(_)) => Some(fhir::Sort::Int), + fhir::Res::PrimTy(PrimTy::Bool) => Some(fhir::Sort::Bool), + fhir::Res::PrimTy(PrimTy::Float(..) | PrimTy::Str | PrimTy::Char) => Some(fhir::Sort::Unit), + fhir::Res::Def(DefKind::TyAlias { .. } | DefKind::Enum | DefKind::Struct, def_id) => { + // TODO: duplication with sort_of_path + let mut sort_args = vec![]; + if let Ok(generics) = genv.generics_of(def_id) { + for (param, arg) in generics.params.iter().zip(&path.generics) { + if let rty::GenericParamDefKind::SplTy = param.kind { + let surface::GenericArg::Type(ty) = arg else { return None }; + let surface::BaseTyKind::Path(path) = &ty.as_bty()?.kind else { + return None; + }; + let sort = sort_of_surface_path(genv, resolver_output, path)?; + sort_args.push(sort); + } + } + }; + Some(fhir::Sort::Record(def_id, List::from_vec(sort_args))) + } + fhir::Res::Def(DefKind::TyParam, def_id) => { + let param = genv.get_generic_param(def_id.expect_local()); + match ¶m.kind { + fhir::GenericParamKind::BaseTy => Some(fhir::Sort::Param(def_id)), + fhir::GenericParamKind::Type { .. } + | fhir::GenericParamKind::Lifetime + | fhir::GenericParamKind::SplTy => None, + } + } + + fhir::Res::Def(DefKind::AssocTy | DefKind::OpaqueTy, _) | fhir::Res::SelfTyParam { .. } => { + None + } + fhir::Res::SelfTyAlias { alias_to, .. } => genv.sort_of_self_ty_alias(alias_to), + + fhir::Res::Def(..) => bug!("unexpected res {res:?}"), + } +} +fn as_tuple<'a>(genv: &'a GlobalEnv, sort: &'a fhir::Sort) -> Vec { + if let fhir::Sort::Record(def_id, sort_args) = sort { + genv.index_sorts_of(*def_id, sort_args) } else { - slice::from_ref(sort) + vec![sort.clone()] } } diff --git a/crates/flux-desugar/src/lib.rs b/crates/flux-desugar/src/lib.rs index 864abf1ac6..8ea41caa70 100644 --- a/crates/flux-desugar/src/lib.rs +++ b/crates/flux-desugar/src/lib.rs @@ -1,6 +1,6 @@ //! Desugaring from types in [`flux_syntax::surface`] to types in [`flux_middle::fhir`] //! -//! # NOTE +//! # Generics and Desugaring //! //! Desugaring requires knowing the sort of each type so we can correctly resolve binders declared with //! @ syntax or arg syntax. In particular, to know the sort of a type parameter we need to know its @@ -53,10 +53,9 @@ pub fn desugar_struct_def( let mut cx = DesugarCtxt::new(genv, owner_id, resolver_output, None); // Desugar and insert generics - let (generics, predicates) = cx.as_lift_cx().lift_generics_with_predicates()?; - genv.map().insert_generics(def_id, generics); + let predicates = cx.as_lift_cx().lift_predicates()?; - // Desugar of struct_def needs to happen AFTER inserting generics. See crate level comment + // Desugar of struct_def needs to happen AFTER inserting generics. See #generics-and-desugaring let struct_def = cx.desugar_struct_def(struct_def, &mut Binders::new())?; if config::dump_fhir() { dbg::dump_item_info(genv.tcx, owner_id, "fhir", &struct_def).unwrap(); @@ -173,15 +172,25 @@ pub fn desugar_fn_sig( } /// HACK(nilehmann) this is a bit of a hack. We use it to properly register generics and predicates -/// for items that don't have surface syntax (impl blocks, traits, ...). In this cases we just [lift] -/// them from hir. +/// for items that don't have surface syntax (impl blocks, traits, ...), or for `impl` blocks with +/// explicit `generics` annotations. In the former case, we use `desugar`; in the latter cases we +/// just [lift] them from hir. pub fn desugar_generics_and_predicates( genv: &mut GlobalEnv, owner_id: OwnerId, + resolver_output: &ResolverOutput, + generics: Option<&surface::Generics>, ) -> Result<(), ErrorGuaranteed> { - let def_id = owner_id.def_id; - let (generics, predicates) = + let (lifted_generics, predicates) = LiftCtxt::new(genv.tcx, genv.sess, owner_id, None).lift_generics_with_predicates()?; + + let generics = if let Some(generics) = generics { + let cx = DesugarCtxt::new(genv, owner_id, resolver_output, None); + cx.desugar_generics(generics)? + } else { + lifted_generics + }; + let def_id = owner_id.def_id; genv.map().insert_generics(def_id, generics); genv.map_mut().insert_generic_predicates(def_id, predicates); Ok(()) diff --git a/crates/flux-driver/src/callbacks.rs b/crates/flux-driver/src/callbacks.rs index 009cb01dd8..c90405f6bb 100644 --- a/crates/flux-driver/src/callbacks.rs +++ b/crates/flux-driver/src/callbacks.rs @@ -153,15 +153,18 @@ fn stage1_desugar(genv: &mut GlobalEnv, specs: &Specs) -> Result<(), ErrorGuaran .err() .or(err); - // Register RefinedBys + // Register RefinedBys (for structs and enums, which also registers their Generics) err = specs .refined_bys() .try_for_each_exhaust(|(owner_id, refined_by)| { + let generics = lift::lift_generics(tcx, sess, owner_id)?; let refined_by = if let Some(refined_by) = refined_by { - desugar::desugar_refined_by(sess, map.sort_decls(), owner_id, refined_by)? + let generics = tcx.generics_of(owner_id); + desugar::desugar_refined_by(sess, map.sort_decls(), owner_id, generics, refined_by)? } else { lift::lift_refined_by(tcx, owner_id) }; + map.insert_generics(owner_id.def_id, generics.with_refined_by(&refined_by)); map.insert_refined_by(owner_id.def_id, refined_by); Ok(()) }) @@ -273,7 +276,9 @@ fn desugar_item( let ty_alias = specs.ty_aliases[&owner_id].as_ref(); desugar::desugar_type_alias(genv, owner_id, ty_alias, resolver_output)?; } - hir::ItemKind::OpaqueTy(_) => desugar::desugar_generics_and_predicates(genv, owner_id)?, + hir::ItemKind::OpaqueTy(_) => { + desugar::desugar_generics_and_predicates(genv, owner_id, resolver_output, None)?; + } hir::ItemKind::Enum(..) => { let enum_def = &specs.enums[&owner_id]; desugar::desugar_enum_def(genv, owner_id, enum_def, resolver_output)?; @@ -283,7 +288,7 @@ fn desugar_item( desugar::desugar_struct_def(genv, owner_id, struct_def, resolver_output)?; } hir::ItemKind::Trait(.., items) => { - desugar::desugar_generics_and_predicates(genv, owner_id)?; + desugar::desugar_generics_and_predicates(genv, owner_id, resolver_output, None)?; items.iter().try_for_each_exhaust(|trait_item| { desugar_assoc_item( genv, @@ -295,7 +300,8 @@ fn desugar_item( })?; } hir::ItemKind::Impl(impl_) => { - desugar::desugar_generics_and_predicates(genv, owner_id)?; + let generics = specs.impls.get(&owner_id); + desugar::desugar_generics_and_predicates(genv, owner_id, resolver_output, generics)?; impl_.items.iter().try_for_each_exhaust(|impl_item| { desugar_assoc_item( genv, @@ -316,11 +322,13 @@ fn desugar_assoc_item( specs: &mut Specs, owner_id: OwnerId, kind: hir::AssocItemKind, - resolver_outpt: &ResolverOutput, + resolver_output: &ResolverOutput, ) -> Result<(), ErrorGuaranteed> { match kind { - hir::AssocItemKind::Fn { .. } => desugar_fn_sig(genv, specs, owner_id, resolver_outpt), - hir::AssocItemKind::Type => desugar::desugar_generics_and_predicates(genv, owner_id), + hir::AssocItemKind::Fn { .. } => desugar_fn_sig(genv, specs, owner_id, resolver_output), + hir::AssocItemKind::Type => { + desugar::desugar_generics_and_predicates(genv, owner_id, resolver_output, None) + } hir::AssocItemKind::Const => Ok(()), } } diff --git a/crates/flux-driver/src/collector.rs b/crates/flux-driver/src/collector.rs index 29768cd629..fa81064adb 100644 --- a/crates/flux-driver/src/collector.rs +++ b/crates/flux-driver/src/collector.rs @@ -41,6 +41,7 @@ pub type Ignores = UnordSet; pub(crate) struct Specs { pub fn_sigs: UnordMap, pub structs: FxHashMap, + pub impls: FxHashMap, pub enums: FxHashMap, pub qualifs: Vec, pub func_defs: Vec, @@ -101,6 +102,7 @@ impl<'tcx, 'a> SpecCollector<'tcx, 'a> { ItemKind::Mod(..) => collector.parse_mod_spec(owner_id.def_id, attrs), ItemKind::TyAlias(..) => collector.parse_tyalias_spec(owner_id, attrs), ItemKind::Const(..) => collector.parse_const_spec(item, attrs), + ItemKind::Impl(_) => collector.parse_impl_spec(owner_id, attrs), _ => Ok(()), }; } @@ -185,6 +187,21 @@ impl<'tcx, 'a> SpecCollector<'tcx, 'a> { } } + fn parse_impl_spec( + &mut self, + owner_id: OwnerId, + attrs: &[Attribute], + ) -> Result<(), ErrorGuaranteed> { + let mut attrs = self.parse_flux_attrs(attrs)?; + self.report_dups(&attrs)?; + + if let Some(generics) = attrs.generics() { + self.specs.impls.insert(owner_id, generics); + } + + Ok(()) + } + fn parse_tyalias_spec( &mut self, owner_id: OwnerId, @@ -209,6 +226,8 @@ impl<'tcx, 'a> SpecCollector<'tcx, 'a> { let refined_by = attrs.refined_by(); + let generics = attrs.generics(); + let fields = data .fields() .iter() @@ -228,9 +247,10 @@ impl<'tcx, 'a> SpecCollector<'tcx, 'a> { .insert(extern_def_id, owner_id.def_id); } - self.specs - .structs - .insert(owner_id, surface::StructDef { refined_by, fields, opaque, invariants }); + self.specs.structs.insert( + owner_id, + surface::StructDef { refined_by, generics, fields, opaque, invariants }, + ); Ok(()) } @@ -361,6 +381,9 @@ impl<'tcx, 'a> SpecCollector<'tcx, 'a> { ("refined_by", AttrArgs::Delimited(dargs)) => { self.parse(dargs, ParseSess::parse_refined_by, FluxAttrKind::RefinedBy)? } + ("generics", AttrArgs::Delimited(dargs)) => { + self.parse(dargs, ParseSess::parse_generics, FluxAttrKind::Generics)? + } ("field", AttrArgs::Delimited(dargs)) => { self.parse(dargs, ParseSess::parse_type, FluxAttrKind::Field)? } @@ -489,6 +512,7 @@ impl Specs { fn new() -> Specs { Specs { fn_sigs: Default::default(), + impls: Default::default(), structs: Default::default(), enums: Default::default(), qualifs: Vec::default(), @@ -546,6 +570,7 @@ enum FluxAttrKind { Opaque, FnSig(surface::FnSig), RefinedBy(surface::RefinedBy), + Generics(surface::Generics), QualNames(surface::QualNames), Items(Vec), TypeAlias(surface::TyAlias), @@ -636,6 +661,10 @@ impl FluxAttrs { read_attr!(self, RefinedBy) } + fn generics(&mut self) -> Option { + read_attr!(self, Generics) + } + fn field(&mut self) -> Option { read_attr!(self, Field) } @@ -674,6 +703,7 @@ impl FluxAttrKind { FluxAttrKind::FnSig(_) => attr_name!(FnSig), FluxAttrKind::ConstSig(_) => attr_name!(ConstSig), FluxAttrKind::RefinedBy(_) => attr_name!(RefinedBy), + FluxAttrKind::Generics(_) => attr_name!(Generics), FluxAttrKind::Items(_) => attr_name!(Items), FluxAttrKind::QualNames(_) => attr_name!(QualNames), FluxAttrKind::Field(_) => attr_name!(Field), diff --git a/crates/flux-fhir-analysis/locales/en-US.ftl b/crates/flux-fhir-analysis/locales/en-US.ftl index b475be788a..4ca0962e42 100644 --- a/crates/flux-fhir-analysis/locales/en-US.ftl +++ b/crates/flux-fhir-analysis/locales/en-US.ftl @@ -80,6 +80,9 @@ fhir_analysis_expected_numeric = fhir_analysis_no_equality = values of sort `{$sort}` cannot be compared for equality +fhir_analysis_invalid_base_instance = + values of this type cannot be used as base sorted instances + fhir_analysis_param_not_determined = parameter `{$sym}` cannot be determined .label = undetermined parameter @@ -153,4 +156,3 @@ fhir_analysis_assoc_type_not_found = associated type not found .label = cannot resolve this associated type .note = flux cannot resolved associated types if they are defined in a super trait - diff --git a/crates/flux-fhir-analysis/src/conv.rs b/crates/flux-fhir-analysis/src/conv.rs index f63b7ebe1a..bdab838aec 100644 --- a/crates/flux-fhir-analysis/src/conv.rs +++ b/crates/flux-fhir-analysis/src/conv.rs @@ -162,6 +162,7 @@ pub(crate) fn conv_generics( fhir::GenericParamKind::Type { default } => { rty::GenericParamDefKind::Type { has_default: default.is_some() } } + fhir::GenericParamKind::SplTy => rty::GenericParamDefKind::SplTy, fhir::GenericParamKind::BaseTy => rty::GenericParamDefKind::BaseTy, fhir::GenericParamKind::Lifetime => rty::GenericParamDefKind::Lifetime, }; @@ -193,12 +194,24 @@ pub(crate) fn conv_generics( }) } +fn sort_args_for_adt(genv: &GlobalEnv, def_id: impl Into) -> List { + let mut sort_args = vec![]; + for param in &genv.tcx.generics_of(def_id.into()).params { + if let rustc_middle::ty::GenericParamDefKind::Type { .. } = param.kind { + sort_args.push(fhir::Sort::Param(param.def_id)); + } + } + List::from_vec(sort_args) +} + pub(crate) fn adt_def_for_struct( genv: &GlobalEnv, invariants: Vec, struct_def: &fhir::StructDef, ) -> rty::AdtDef { - let sort = rty::Sort::tuple(conv_sorts(genv, genv.index_sorts_of(struct_def.owner_id))); + let def_id = struct_def.owner_id; + let sort_args = sort_args_for_adt(genv, def_id); + let sort = rty::Sort::tuple(conv_sorts(genv, &genv.index_sorts_of(def_id, &sort_args))); let adt_def = lowering::lower_adt_def(&genv.tcx.adt_def(struct_def.owner_id)); rty::AdtDef::new(adt_def, sort, invariants, struct_def.is_opaque()) } @@ -208,7 +221,9 @@ pub(crate) fn adt_def_for_enum( invariants: Vec, enum_def: &fhir::EnumDef, ) -> rty::AdtDef { - let sort = rty::Sort::tuple(conv_sorts(genv, genv.index_sorts_of(enum_def.owner_id))); + let def_id = enum_def.owner_id; + let sort_args = sort_args_for_adt(genv, def_id); + let sort = rty::Sort::tuple(conv_sorts(genv, &genv.index_sorts_of(def_id, &sort_args))); let adt_def = lowering::lower_adt_def(&genv.tcx.adt_def(enum_def.owner_id)); rty::AdtDef::new(adt_def, sort, invariants, false) } @@ -381,8 +396,9 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> { let kind = rty::BoundRegionKind::BrNamed(def_id.to_def_id(), name); Ok(rty::BoundVariableKind::Region(kind)) } - fhir::GenericParamKind::Type { default: _ } => bug!("unexpected!"), - fhir::GenericParamKind::BaseTy => bug!("unexpected!"), + fhir::GenericParamKind::Type { default: _ } + | fhir::GenericParamKind::BaseTy + | fhir::GenericParamKind::SplTy => bug!("unexpected!"), } } @@ -640,7 +656,6 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> { return Ok(rty::Ty::param(param_ty)); } } - let sort = conv_sort(self.genv, &sort.unwrap()); if sort.is_unit() { let idx = rty::Index::from(rty::Expr::unit()); @@ -710,7 +725,7 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> { let expr = self.add_coercions(rty::Expr::abs(body), *fhir_id); (expr, rty::TupleTree::Leaf(false)) } - fhir::RefineArg::Record(_, flds, ..) => { + fhir::RefineArg::Record(_, _, flds, ..) => { let mut exprs = vec![]; let mut is_binder = vec![]; for arg in flds { @@ -1087,10 +1102,12 @@ impl LookupResult<'_> { fn is_record(&self) -> Option { match &self.kind { LookupResultKind::LateBoundList { - entry: Entry::Sort { sort: fhir::Sort::Record(def_id), .. }, + entry: Entry::Sort { sort: fhir::Sort::Record(def_id, _), .. }, .. } => Some(*def_id), - LookupResultKind::EarlyBound { sort: fhir::Sort::Record(def_id), .. } => Some(*def_id), + LookupResultKind::EarlyBound { sort: fhir::Sort::Record(def_id, _), .. } => { + Some(*def_id) + } _ => None, } } @@ -1142,8 +1159,8 @@ pub fn conv_sort(genv: &GlobalEnv, sort: &fhir::Sort) -> rty::Sort { fhir::Sort::Loc => rty::Sort::Loc, fhir::Sort::Unit => rty::Sort::unit(), fhir::Sort::Func(fsort) => rty::Sort::Func(conv_func_sort(genv, fsort)), - fhir::Sort::Record(def_id) => { - rty::Sort::tuple(conv_sorts(genv, genv.index_sorts_of(*def_id))) + fhir::Sort::Record(def_id, sort_args) => { + rty::Sort::tuple(conv_sorts(genv, &genv.index_sorts_of(*def_id, sort_args))) } fhir::Sort::App(ctor, args) => { let ctor = conv_sort_ctor(ctor); diff --git a/crates/flux-fhir-analysis/src/lib.rs b/crates/flux-fhir-analysis/src/lib.rs index e808d6e6af..95814ecf14 100644 --- a/crates/flux-fhir-analysis/src/lib.rs +++ b/crates/flux-fhir-analysis/src/lib.rs @@ -294,7 +294,6 @@ fn check_wf_rust_item(genv: &GlobalEnv, def_id: LocalDefId) -> QueryResult { let owner_id = OwnerId { def_id }; - let fn_sig = genv.map().get_fn_sig(def_id); let mut wfckresults = wf::check_fn_sig(genv, fn_sig, owner_id)?; annot_check::check_fn_sig(genv.tcx, genv.sess, &mut wfckresults, owner_id, fn_sig)?; diff --git a/crates/flux-fhir-analysis/src/wf/errors.rs b/crates/flux-fhir-analysis/src/wf/errors.rs index 1ee9abac52..49c551eae9 100644 --- a/crates/flux-fhir-analysis/src/wf/errors.rs +++ b/crates/flux-fhir-analysis/src/wf/errors.rs @@ -196,6 +196,20 @@ impl<'a> InvalidPrimitiveDotAccess<'a> { } } +#[derive(Diagnostic)] +#[diag(fhir_analysis_invalid_base_instance, code = "FLUX")] +pub(super) struct InvalidBaseInstance<'a> { + #[primary_span] + span: Span, + ty: &'a fhir::Ty, +} + +impl<'a> InvalidBaseInstance<'a> { + pub(super) fn new(ty: &'a fhir::Ty) -> Self { + Self { ty, span: ty.span } + } +} + #[derive(Diagnostic)] #[diag(fhir_analysis_no_equality, code = "FLUX")] pub(super) struct NoEquality<'a> { diff --git a/crates/flux-fhir-analysis/src/wf/mod.rs b/crates/flux-fhir-analysis/src/wf/mod.rs index 2c654d8998..050f5ae50c 100644 --- a/crates/flux-fhir-analysis/src/wf/mod.rs +++ b/crates/flux-fhir-analysis/src/wf/mod.rs @@ -7,11 +7,12 @@ mod sortck; use std::iter; -use flux_common::{iter::IterExt, span_bug}; -use flux_errors::FluxSession; +use flux_common::{bug, iter::IterExt, span_bug}; +use flux_errors::{FluxSession, ResultExt}; use flux_middle::{ fhir::{self, FluxOwnerId, SurfaceIdent, WfckResults}, global_env::GlobalEnv, + rty::GenericParamDefKind, }; use rustc_data_structures::{ snapshot_map::{self, SnapshotMap}, @@ -19,7 +20,7 @@ use rustc_data_structures::{ }; use rustc_errors::{ErrorGuaranteed, IntoDiagnostic}; use rustc_hash::FxHashSet; -use rustc_hir::{def::DefKind, OwnerId}; +use rustc_hir::{def::DefKind, def_id::DefId, OwnerId}; use rustc_span::Symbol; use self::sortck::InferCtxt; @@ -233,8 +234,9 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { ) -> Result<(), ErrorGuaranteed> { match bound { fhir::GenericBound::Trait(trait_ref, _) => self.check_path(infcx, &trait_ref.trait_ref), - fhir::GenericBound::LangItemTrait(_, args, bindings) => { - self.check_generic_args(infcx, args)?; + fhir::GenericBound::LangItemTrait(lang_item, args, bindings) => { + let def_id = self.genv.tcx.require_lang_item(*lang_item, None); + self.check_generic_args(infcx, def_id, args)?; self.check_type_bindings(infcx, bindings)?; Ok(()) } @@ -360,9 +362,10 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { self.check_type(infcx, ty)?; self.check_pred(infcx, pred) } - fhir::TyKind::OpaqueDef(_, args, _refine_args, _) => { + fhir::TyKind::OpaqueDef(item_id, args, _refine_args, _) => { // TODO sanity check the _refine_args (though they should never fail!) but we'd need their expected sorts - self.check_generic_args(infcx, args) + let def_id = item_id.owner_id.to_def_id(); + self.check_generic_args(infcx, def_id, args) } fhir::TyKind::RawPtr(ty, _) => self.check_type(infcx, ty), fhir::TyKind::Hole(_) | fhir::TyKind::Never => Ok(()), @@ -382,11 +385,52 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { .try_collect_exhaust() } + fn check_ty_is_base(&self, ty: &fhir::Ty) -> Result<(), ErrorGuaranteed> { + match &ty.kind { + fhir::TyKind::BaseTy(_) | fhir::TyKind::Indexed(_, _) => Ok(()), + fhir::TyKind::Tuple(tys) => { + for ty in tys { + self.check_ty_is_base(ty)?; + } + Ok(()) + } + fhir::TyKind::Constr(_, ty) | fhir::TyKind::Exists(_, ty) => self.check_ty_is_base(ty), + + fhir::TyKind::Ptr(_, _) + | fhir::TyKind::Ref(_, _) + | fhir::TyKind::Array(_, _) + | fhir::TyKind::RawPtr(_, _) + | fhir::TyKind::OpaqueDef(_, _, _, _) + | fhir::TyKind::Never + | fhir::TyKind::Hole(_) => self.emit_err(errors::InvalidBaseInstance::new(ty)), + } + } + + fn check_generic_args_kinds( + &self, + def_id: DefId, + args: &[fhir::GenericArg], + ) -> Result<(), ErrorGuaranteed> { + let generics = self.genv.generics_of(def_id).emit(self.genv.sess)?; + for (arg, param) in iter::zip(args, &generics.params) { + if param.kind == GenericParamDefKind::SplTy { + if let fhir::GenericArg::Type(ty) = arg { + self.check_ty_is_base(ty)?; + } else { + bug!("expected type argument got `{arg:?}`"); + } + } + } + Ok(()) + } + fn check_generic_args( &mut self, infcx: &mut InferCtxt, + def_id: DefId, args: &[fhir::GenericArg], ) -> Result<(), ErrorGuaranteed> { + self.check_generic_args_kinds(def_id, args)?; args.iter() .try_for_each_exhaust(|arg| self.check_generic_arg(infcx, arg)) } @@ -436,7 +480,7 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { ) -> Result<(), ErrorGuaranteed> { match &path.res { fhir::Res::Def(DefKind::TyAlias { .. }, def_id) => { - let sorts = self.genv.early_bound_sorts_of(*def_id); + let sorts = self.genv.early_bound_sorts_of(*def_id, &[]); if path.refine.len() != sorts.len() { return self.emit_err(errors::EarlyBoundArgCountMismatch::new( path.span, @@ -445,7 +489,7 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { )); } iter::zip(&path.refine, sorts) - .try_for_each_exhaust(|(arg, sort)| self.check_refine_arg(infcx, arg, sort))?; + .try_for_each_exhaust(|(arg, sort)| self.check_refine_arg(infcx, arg, &sort))?; } fhir::Res::SelfTyParam { .. } | fhir::Res::SelfTyAlias { .. } @@ -453,12 +497,14 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { | fhir::Res::PrimTy(..) => {} } let snapshot = self.xi.snapshot(); - let args = self.check_generic_args(infcx, &path.args); + + if let fhir::Res::Def(_kind, did) = &path.res /*&& !matches!(_kind, DefKind::TyParam) */ && !path.args.is_empty() { + self.check_generic_args(infcx, *did, &path.args)?; + } let bindings = self.check_type_bindings(infcx, &path.bindings); if !self.genv.is_box(path.res) { self.xi.rollback_to(snapshot); } - args?; bindings?; Ok(()) } @@ -498,7 +544,7 @@ impl<'a, 'tcx> Wf<'a, 'tcx> { Ok(()) } fhir::RefineArg::Abs(_, body, ..) => self.check_param_uses_expr(infcx, body, true), - fhir::RefineArg::Record(_, flds, ..) => { + fhir::RefineArg::Record(_, _, flds, ..) => { flds.iter() .try_for_each_exhaust(|arg| self.check_param_uses_refine_arg(infcx, arg)) } diff --git a/crates/flux-fhir-analysis/src/wf/sortck.rs b/crates/flux-fhir-analysis/src/wf/sortck.rs index a3a21f4dae..a776abcb7a 100644 --- a/crates/flux-fhir-analysis/src/wf/sortck.rs +++ b/crates/flux-fhir-analysis/src/wf/sortck.rs @@ -16,7 +16,7 @@ use rustc_span::Span; use super::errors; pub(super) struct InferCtxt<'a, 'tcx> { - genv: &'a GlobalEnv<'a, 'tcx>, + pub genv: &'a GlobalEnv<'a, 'tcx>, sorts: UnordMap, unification_table: InPlaceUnificationTable, wfckresults: fhir::WfckResults, @@ -45,9 +45,9 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { fhir::RefineArg::Abs(params, body, span, fhir_id) => { self.check_abs(params, body, span, fhir_id, expected) } - fhir::RefineArg::Record(def_id, flds, span) => { - self.check_record(*def_id, flds, *span)?; - let found = fhir::Sort::Record(*def_id); + fhir::RefineArg::Record(def_id, sort_args, flds, span) => { + self.check_record(*def_id, sort_args, flds, *span)?; + let found = fhir::Sort::Record(*def_id, sort_args.clone()); if &found != expected { return Err(self.emit_sort_mismatch(*span, expected, &found)); } @@ -92,10 +92,11 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { fn check_record( &mut self, def_id: DefId, + sort_args: &[fhir::Sort], args: &[fhir::RefineArg], span: Span, ) -> Result<(), ErrorGuaranteed> { - let sorts = self.genv.index_sorts_of(def_id); + let sorts = self.genv.index_sorts_of(def_id, sort_args); if args.len() != sorts.len() { return Err(self.emit_err(errors::ArgCountMismatch::new( Some(span), @@ -105,7 +106,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { ))); } izip!(args, sorts) - .map(|(arg, expected)| self.check_refine_arg(arg, expected)) + .map(|(arg, expected)| self.check_refine_arg(arg, &expected)) .try_collect_exhaust() } @@ -207,11 +208,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { } fhir::ExprKind::Dot(var, fld) => { let sort = self[var.name].clone(); - match sort { - fhir::Sort::Record(def_id) => { + match &sort { + fhir::Sort::Record(def_id, sort_args) => { self.genv - .field_sort(def_id, fld.name) - .cloned() + .field_sort(*def_id, sort_args.clone(), fld.name) .ok_or_else(|| self.emit_field_not_found(&sort, *fld)) } fhir::Sort::Bool | fhir::Sort::Int | fhir::Sort::Real => { @@ -519,10 +519,12 @@ impl<'a> InferCtxt<'a, '_> { }) } - fn is_single_field_record(&mut self, sort: &fhir::Sort) -> Option<&'a fhir::Sort> { + fn is_single_field_record(&mut self, sort: &fhir::Sort) -> Option { self.resolve_sort(sort).and_then(|s| { - if let fhir::Sort::Record(def_id) = s && let [sort] = self.genv.index_sorts_of(def_id) { - Some(sort) + if let fhir::Sort::Record(def_id, sort_args) = s && + let [sort] = &self.genv.index_sorts_of(def_id, &sort_args)[..] + { + Some(sort.clone()) } else { None } diff --git a/crates/flux-middle/src/fhir.rs b/crates/flux-middle/src/fhir.rs index 402b264dbf..a0aa909f02 100644 --- a/crates/flux-middle/src/fhir.rs +++ b/crates/flux-middle/src/fhir.rs @@ -61,9 +61,10 @@ pub struct GenericParam { pub kind: GenericParamKind, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum GenericParamKind { Type { default: Option }, + SplTy, BaseTy, Lifetime, } @@ -392,7 +393,7 @@ pub enum RefineArg { is_binder: bool, }, Abs(Vec, Expr, Span, FhirId), - Record(DefId, Vec, Span), + Record(DefId, List, Vec, Span), } /// These are types of things that may be refined with indices or existentials @@ -432,7 +433,6 @@ pub struct TypeBinding { pub enum GenericArg { Lifetime(Lifetime), Type(Ty), - // Constraint(SurfaceIdent, Ty), } #[derive(Eq, PartialEq, Debug, Copy, Clone)] @@ -497,7 +497,8 @@ pub enum Sort { Var(usize), /// A record sort corresponds to the sort associated with a type alias or an adt (struct/enum). /// Values of a record sort can be projected using dot notation to extract their fields. - Record(DefId), + /// the List is for the type parameters of (generic) record sorts + Record(DefId, List), /// The sort associated to a type variable Param(DefId), /// A sort that needs to be inferred @@ -631,9 +632,7 @@ impl SortCtor { impl Ty { pub fn as_path(&self) -> Option<&Path> { match &self.kind { - TyKind::BaseTy(BaseTy { - kind: BaseTyKind::Path(QPath::Resolved(None, path)), .. - }) => Some(path), + TyKind::BaseTy(bty) => bty.as_path(), _ => None, } } @@ -646,6 +645,13 @@ impl BaseTy { BaseTyKind::Path(QPath::Resolved(_, Path { res: Res::PrimTy(PrimTy::Bool), .. })) ) } + + pub fn as_path(&self) -> Option<&Path> { + match &self.kind { + BaseTyKind::Path(QPath::Resolved(None, path)) => Some(path), + _ => None, + } + } } impl Res { @@ -725,10 +731,17 @@ impl Ident { /// in a definition. /// /// [early bound]: https://rustc-dev-guide.rust-lang.org/early-late-bound.html +/// +/// Sort parameters e.g. #[flux::refined_by( elems: Set )] tracks the mapping from +/// bound Var -> Generic id. e.g. if we have RMap refined_by(keys: Set) +/// then RMapIdx = forall #0. { keys: Set<#0> } +/// and sort_params = vec![T] i.e. maps Var(0) to T + #[derive(Clone, Debug, TyEncodable, TyDecodable)] pub struct RefinedBy { pub def_id: DefId, pub span: Span, + sort_params: Vec, /// Index parameters indexed by their name and in the same order they appear in the definition. index_params: FxIndexMap, /// The number of early bound parameters @@ -757,6 +770,7 @@ pub enum FuncKind { #[derive(Debug)] pub struct Defn { pub name: Symbol, + pub params: usize, pub args: Vec, pub sort: Sort, pub expr: Expr, @@ -766,6 +780,19 @@ impl Generics { pub(crate) fn get_param(&self, def_id: LocalDefId) -> &GenericParam { self.params.iter().find(|p| p.def_id == def_id).unwrap() } + + pub fn with_refined_by(self, refined_by: &RefinedBy) -> Self { + let mut params = vec![]; + for param in self.params { + let kind = if refined_by.is_base_generic(param.def_id.to_def_id()) { + GenericParamKind::SplTy + } else { + param.kind.clone() + }; + params.push(GenericParam { def_id: param.def_id, kind }); + } + Generics { params } + } } impl RefinedBy { @@ -773,6 +800,7 @@ impl RefinedBy { def_id: impl Into, early_bound_params: impl IntoIterator, index_params: impl IntoIterator, + sort_params: Vec, span: Span, ) -> Self { let mut sorts = early_bound_params.into_iter().collect_vec(); @@ -781,23 +809,44 @@ impl RefinedBy { .into_iter() .inspect(|(_, sort)| sorts.push(sort.clone())) .collect(); - RefinedBy { def_id: def_id.into(), span, index_params, early_bound, sorts } + RefinedBy { def_id: def_id.into(), sort_params, span, index_params, early_bound, sorts } + } + + pub fn trivial(def_id: impl Into, span: Span) -> Self { + RefinedBy { + def_id: def_id.into(), + sort_params: Default::default(), + span, + index_params: Default::default(), + early_bound: 0, + sorts: vec![], + } } pub fn field_index(&self, fld: Symbol) -> Option { self.index_params.get_index_of(&fld) } - pub fn field_sort(&self, fld: Symbol) -> Option<&Sort> { - self.index_params.get(&fld) + pub fn field_sort(&self, fld: Symbol, args: &[Sort]) -> Option { + self.index_params.get(&fld).map(|sort| sort.subst(args)) } - pub fn early_bound_sorts(&self) -> &[Sort] { - &self.sorts[..self.early_bound] + pub fn early_bound_sorts(&self, args: &[Sort]) -> Vec { + self.sorts[..self.early_bound] + .iter() + .map(|sort| sort.subst(args)) + .collect() } - pub fn index_sorts(&self) -> &[Sort] { - &self.sorts[self.early_bound..] + pub fn index_sorts(&self, args: &[Sort]) -> Vec { + self.sorts[self.early_bound..] + .iter() + .map(|sort| sort.subst(args)) + .collect() + } + + fn is_base_generic(&self, def_id: DefId) -> bool { + self.sort_params.contains(&def_id) } } @@ -835,7 +884,7 @@ impl Sort { Self::App(SortCtor::Map, List::from_vec(vec![k, v])) } - /// replace all "sort-parameters" (indexed 0...n-1) with the corresponding sort in `subst` + /// replace all "sort-vars" (indexed 0...n-1) with the corresponding sort in `subst` fn subst(&self, subst: &[Sort]) -> Sort { match self { Sort::Int @@ -846,14 +895,21 @@ impl Sort { | Sort::BitVec(_) | Sort::Param(_) | Sort::Wildcard - | Sort::Record(_) + | Sort::Record(_, _) | Sort::Infer(_) => self.clone(), Sort::Var(i) => subst[*i].clone(), Sort::App(c, args) => { let args = args.iter().map(|arg| arg.subst(subst)).collect(); Sort::App(c.clone(), args) } - Sort::Func(_) => bug!("unexpected subst in (nested) func-sort"), + Sort::Func(fsort) => { + if fsort.params == 0 { + let fsort = fsort.instantiate(subst); + Sort::Func(PolyFuncSort { params: 0, fsort }) + } else { + bug!("unexpected subst in (nested) func-sort") + } + } } } } @@ -1622,7 +1678,18 @@ impl fmt::Debug for Sort { Sort::Loc => write!(f, "loc"), Sort::Func(sort) => write!(f, "{sort}"), Sort::Unit => write!(f, "()"), - Sort::Record(def_id) => write!(f, "{}", pretty::def_id_to_string(*def_id)), + Sort::Record(def_id, sort_args) => { + if sort_args.is_empty() { + write!(f, "{}", pretty::def_id_to_string(*def_id)) + } else { + write!( + f, + "{}<{}>", + pretty::def_id_to_string(*def_id), + sort_args.iter().join(", ") + ) + } + } Sort::Param(def_id) => write!(f, "sortof({})", pretty::def_id_to_string(*def_id)), Sort::Wildcard => write!(f, "_"), Sort::Infer(vid) => write!(f, "{vid:?}"), diff --git a/crates/flux-middle/src/fhir/lift.rs b/crates/flux-middle/src/fhir/lift.rs index 3acca38b54..adc4c690d4 100644 --- a/crates/flux-middle/src/fhir/lift.rs +++ b/crates/flux-middle/src/fhir/lift.rs @@ -11,7 +11,7 @@ use rustc_hir as hir; use rustc_hir::def_id::LocalDefId; use rustc_middle::{middle::resolve_bound_vars::ResolvedArg, ty::TyCtxt}; -use crate::fhir; +use crate::{fhir, intern::List}; pub struct LiftCtxt<'a, 'tcx> { tcx: TyCtxt<'tcx>, @@ -25,7 +25,7 @@ pub fn lift_refined_by(tcx: TyCtxt, owner_id: OwnerId) -> fhir::RefinedBy { let item = tcx.hir().expect_item(def_id); match item.kind { hir::ItemKind::TyAlias(..) | hir::ItemKind::Struct(..) | hir::ItemKind::Enum(..) => { - fhir::RefinedBy::new(def_id, [], [], item.ident.span) + fhir::RefinedBy::trivial(def_id, item.ident.span) } _ => { bug!("expected struct, enum or type alias"); @@ -33,6 +33,14 @@ pub fn lift_refined_by(tcx: TyCtxt, owner_id: OwnerId) -> fhir::RefinedBy { } } +pub fn lift_generics( + tcx: TyCtxt, + sess: &FluxSession, + owner_id: OwnerId, +) -> Result { + LiftCtxt::new(tcx, sess, owner_id, None).lift_generics() +} + pub fn lift_type_alias( tcx: TyCtxt, sess: &FluxSession, @@ -151,6 +159,11 @@ impl<'a, 'tcx> LiftCtxt<'a, 'tcx> { self.lift_generics_inner(generics) } + pub fn lift_predicates(&mut self) -> Result { + let generics = self.tcx.hir().get_generics(self.owner.def_id).unwrap(); + self.lift_generic_predicates(generics) + } + fn lift_generic_param( &mut self, param: &hir::GenericParam, @@ -375,6 +388,7 @@ impl<'a, 'tcx> LiftCtxt<'a, 'tcx> { bty, idx: fhir::RefineArg::Record( self.owner.to_def_id(), + List::empty(), // TODO:RJ: or should we use the generics and just make it T1,...Tn? vec![], generics.span.shrink_to_hi(), ), diff --git a/crates/flux-middle/src/global_env.rs b/crates/flux-middle/src/global_env.rs index 8591b72045..cd39b0038d 100644 --- a/crates/flux-middle/src/global_env.rs +++ b/crates/flux-middle/src/global_env.rs @@ -16,7 +16,7 @@ use crate::{ fhir::{self, FluxLocalDefId, VariantIdx}, intern::List, queries::{Providers, Queries, QueryResult}, - rty::{self, fold::TypeFoldable, normalize::Defns, refining::Refiner}, + rty::{self, fold::TypeFoldable, normalize::Defns, refining::Refiner, GenericParamDefKind}, rustc::{self, ty}, }; @@ -180,63 +180,114 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { pub fn sort_of_bty(&self, bty: &fhir::BaseTy) -> Option { match &bty.kind { - fhir::BaseTyKind::Path(fhir::QPath::Resolved(_, fhir::Path { res, .. })) => { - self.sort_of_res(*res) - } + fhir::BaseTyKind::Path(fhir::QPath::Resolved(_, path)) => self.sort_of_path(path), fhir::BaseTyKind::Slice(_) => Some(fhir::Sort::Int), } } - pub fn index_sorts_of(&self, def_id: impl Into) -> &[fhir::Sort] { + pub fn index_sorts_of( + &self, + def_id: impl Into, + sort_args: &[fhir::Sort], + ) -> Vec { let def_id = def_id.into(); if let Some(local_id) = def_id.as_local().or_else(|| self.map().get_extern(def_id)) { - self.map().refined_by(local_id).index_sorts() + self.map().refined_by(local_id).index_sorts(sort_args) } else { self.cstore() .refined_by(def_id) - .map(fhir::RefinedBy::index_sorts) + .map(|rby| rby.index_sorts(sort_args)) .unwrap_or_default() } } - pub fn sort_of_res(&self, res: fhir::Res) -> Option { - // CODESYNC(sort-of, 4) sorts should be given consistently - match res { + pub fn sort_of_path(&self, path: &fhir::Path) -> Option { + // CODESYNC(sort-of-path, 2) sorts should be given consistently + match path.res { fhir::Res::PrimTy(PrimTy::Int(_) | PrimTy::Uint(_)) => Some(fhir::Sort::Int), fhir::Res::PrimTy(PrimTy::Bool) => Some(fhir::Sort::Bool), fhir::Res::PrimTy(PrimTy::Float(..) | PrimTy::Str | PrimTy::Char) => { Some(fhir::Sort::Unit) } fhir::Res::Def(DefKind::TyAlias { .. } | DefKind::Enum | DefKind::Struct, def_id) => { - Some(fhir::Sort::Record(def_id)) - } - fhir::Res::SelfTyAlias { alias_to, .. } => { - let self_ty = self.tcx.type_of(alias_to).skip_binder(); - self.sort_of_self_ty(self_ty) - } - fhir::Res::Def(DefKind::TyParam, def_id) => { - let param = self.get_generic_param(def_id.expect_local()); - match ¶m.kind { - fhir::GenericParamKind::BaseTy => Some(fhir::Sort::Param(def_id)), - fhir::GenericParamKind::Type { .. } | fhir::GenericParamKind::Lifetime => None, - } + let mut sort_args = vec![]; + if let Ok(generics) = self.generics_of(def_id) { + for (param, arg) in generics.params.iter().zip(&path.args) { + if let GenericParamDefKind::SplTy = param.kind { + let fhir::GenericArg::Type(ty) = arg else { return None }; + let sort = self.sort_of_ty(ty)?; + sort_args.push(sort); + } + } + }; + Some(fhir::Sort::Record(def_id, List::from_vec(sort_args))) } + fhir::Res::SelfTyAlias { alias_to, .. } => self.sort_of_self_ty_alias(alias_to), + fhir::Res::Def(DefKind::TyParam, def_id) => self.sort_of_generic_param(def_id), fhir::Res::Def(DefKind::AssocTy | DefKind::OpaqueTy, _) | fhir::Res::SelfTyParam { .. } => None, - fhir::Res::Def(..) => bug!("unexpected res {res:?}"), + fhir::Res::Def(..) => bug!("unexpected res {:?}", path.res), + } + } + + pub fn sort_of_self_ty_alias(&self, alias_to: DefId) -> Option { + let self_ty = self.tcx.type_of(alias_to).instantiate_identity(); + self.sort_of_self_ty(alias_to, self_ty) + } + + fn sort_of_generic_param(&self, def_id: DefId) -> Option { + let param = self.get_generic_param(def_id.expect_local()); + match ¶m.kind { + fhir::GenericParamKind::BaseTy | fhir::GenericParamKind::SplTy => { + Some(fhir::Sort::Param(def_id)) + } + fhir::GenericParamKind::Type { .. } | fhir::GenericParamKind::Lifetime => None, + } + } + + fn sort_of_ty(&self, ty: &fhir::Ty) -> Option { + match &ty.kind { + fhir::TyKind::BaseTy(bty) | fhir::TyKind::Indexed(bty, _) => { + self.sort_of_path(bty.as_path()?) + } + fhir::TyKind::Exists(_, ty) | fhir::TyKind::Constr(_, ty) => self.sort_of_ty(ty), + fhir::TyKind::RawPtr(_, _) + | fhir::TyKind::Ref(_, _) + | fhir::TyKind::Tuple(_) + | fhir::TyKind::Array(_, _) + | fhir::TyKind::Never => Some(fhir::Sort::Unit), + fhir::TyKind::Hole(_) => Some(fhir::Sort::Wildcard), + fhir::TyKind::Ptr(_, _) => None, + fhir::TyKind::OpaqueDef(_, _, _, _) => None, } } - fn sort_of_self_ty(&self, ty: rustc_middle::ty::Ty) -> Option { + fn sort_of_self_ty(&self, def_id: DefId, ty: rustc_middle::ty::Ty) -> Option { use rustc_middle::ty; - // CODESYNC(sort-of, 4) sorts should be given consistently + // CODESYNC(sort-of, 3) sorts should be given consistently match ty.kind() { ty::TyKind::Bool => Some(fhir::Sort::Bool), ty::TyKind::Slice(_) | ty::TyKind::Int(_) | ty::TyKind::Uint(_) => { Some(fhir::Sort::Int) } - ty::TyKind::Adt(adt_def, _) => Some(fhir::Sort::Record(adt_def.did())), - ty::TyKind::Param(_) => todo!(), + ty::TyKind::Adt(adt_def, args) => { + let mut sort_args = vec![]; + for arg in *args { + if let Some(ty) = arg.as_type() && + let Some(sort) = self.sort_of_self_ty(def_id, ty) + { + sort_args.push(sort); + } + } + Some(fhir::Sort::Record(adt_def.did(), List::from_vec(sort_args))) + } + ty::TyKind::Param(p) => { + let generic_param_def = self + .tcx + .generics_of(def_id) + .param_at(p.index as usize, self.tcx); + self.sort_of_generic_param(generic_param_def.def_id) + } ty::TyKind::Float(_) | ty::TyKind::Str | ty::TyKind::Char @@ -249,13 +300,13 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { } } - pub fn early_bound_sorts_of(&self, def_id: DefId) -> &[fhir::Sort] { + pub fn early_bound_sorts_of(&self, def_id: DefId, sort_args: &[fhir::Sort]) -> Vec { if let Some(local_id) = def_id.as_local() { - self.map().refined_by(local_id).early_bound_sorts() + self.map().refined_by(local_id).early_bound_sorts(sort_args) } else { self.cstore() .refined_by(def_id) - .map(fhir::RefinedBy::early_bound_sorts) + .map(|refined_by| refined_by.early_bound_sorts(sort_args)) .unwrap_or_default() } } @@ -268,18 +319,17 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { | fhir::Sort::Real | fhir::Sort::Unit | fhir::Sort::BitVec(_) + | fhir::Sort::Param(_) | fhir::Sort::Var(_) => true, - fhir::Sort::Record(def_id) => { - self.index_sorts_of(*def_id) + fhir::Sort::Record(def_id, sort_args) => { + self.index_sorts_of(*def_id, sort_args) .iter() .all(|sort| self.has_equality(sort)) } fhir::Sort::App(ctor, sorts) => self.ctor_has_equality(ctor, sorts), - fhir::Sort::Loc - | fhir::Sort::Func(_) - | fhir::Sort::Param(_) - | fhir::Sort::Wildcard - | fhir::Sort::Infer(_) => false, + fhir::Sort::Loc | fhir::Sort::Func(_) | fhir::Sort::Wildcard | fhir::Sort::Infer(_) => { + false + } } } @@ -289,14 +339,20 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { args.iter().all(|sort| self.has_equality(sort)) } - pub fn field_sort(&self, def_id: DefId, fld: Symbol) -> Option<&fhir::Sort> { - if let Some(local_id) = def_id.as_local() { - self.map().refined_by(local_id).field_sort(fld) + pub fn field_sort( + &self, + def_id: DefId, + sort_args: List, + fld: Symbol, + ) -> Option { + let poly_sort = if let Some(local_id) = def_id.as_local() { + self.map().refined_by(local_id).field_sort(fld, &sort_args) } else { self.cstore() .refined_by(def_id) - .and_then(|refined_by| refined_by.field_sort(fld)) - } + .and_then(|refined_by| refined_by.field_sort(fld, &sort_args)) + }; + poly_sort } pub fn field_index(&self, def_id: DefId, fld: Symbol) -> Option { diff --git a/crates/flux-middle/src/rty/mod.rs b/crates/flux-middle/src/rty/mod.rs index 7c8144ebb2..7f752fa380 100644 --- a/crates/flux-middle/src/rty/mod.rs +++ b/crates/flux-middle/src/rty/mod.rs @@ -83,6 +83,7 @@ pub struct GenericParamDef { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum GenericParamDefKind { Type { has_default: bool }, + SplTy, BaseTy, Lifetime, Const { has_default: bool }, @@ -205,7 +206,7 @@ pub struct AdtDef(Interned); #[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)] pub struct AdtDefData { invariants: Vec, - sort: Sort, + sort: Sort, // TODO: Binder as there may be Var in `Sort` opaque: bool, rustc: rustc::ty::AdtDef, } @@ -438,10 +439,17 @@ impl GenericArg { bug!("expected `rty::GenericArg::Ty`, found {:?}", self) } } + pub fn is_valid_base_arg(&self) -> bool { + match self { + GenericArg::Ty(ty) => ty.kind().is_valid_base_ty(), + GenericArg::BaseTy(bty) => bty.as_ref().skip_binder().kind().is_valid_base_ty(), + _ => false, + } + } fn from_param_def(param: &GenericParamDef) -> Self { match param.kind { - GenericParamDefKind::Type { .. } => { + GenericParamDefKind::Type { .. } | GenericParamDefKind::SplTy => { let param_ty = ParamTy { index: param.index, name: param.name }; GenericArg::Ty(Ty::param(param_ty)) } @@ -1296,6 +1304,19 @@ impl TyKind { fn intern(self) -> Ty { Interned::new(TyS { kind: self }) } + + fn is_valid_base_ty(&self) -> bool { + match self { + TyKind::Param(_) | TyKind::Indexed(_, _) | TyKind::Exists(_) => true, + TyKind::Constr(_, ty) => ty.kind().is_valid_base_ty(), + TyKind::Uninit + | TyKind::Ptr(_, _) + | TyKind::Discr(_, _) + | TyKind::Downcast(_, _, _, _, _) + | TyKind::Blocked(_) + | TyKind::Alias(_, _) => false, + } + } } impl TyS { @@ -1468,7 +1489,7 @@ impl BaseTy { } pub fn sort(&self) -> Sort { - // CODESYNC(sort-of, 4) sorts should be given consistently + // CODESYNC(sort-of, 3) sorts should be given consistently match self { BaseTy::Int(_) | BaseTy::Uint(_) | BaseTy::Slice(_) => Sort::Int, BaseTy::Bool => Sort::Bool, diff --git a/crates/flux-middle/src/rty/refining.rs b/crates/flux-middle/src/rty/refining.rs index 5e00f3fba6..4853de11e8 100644 --- a/crates/flux-middle/src/rty/refining.rs +++ b/crates/flux-middle/src/rty/refining.rs @@ -231,6 +231,9 @@ impl<'a, 'tcx> Refiner<'a, 'tcx> { (rty::GenericParamDefKind::Type { .. }, rustc::ty::GenericArg::Ty(ty)) => { Ok(rty::GenericArg::Ty(self.refine_ty(ty)?)) } + (rty::GenericParamDefKind::SplTy, rustc::ty::GenericArg::Ty(ty)) => { + Ok(rty::GenericArg::Ty(self.refine_ty(ty)?)) + } (rty::GenericParamDefKind::BaseTy, rustc::ty::GenericArg::Ty(ty)) => { Ok(rty::GenericArg::BaseTy(self.refine_poly_ty(ty)?)) } @@ -320,7 +323,7 @@ impl<'a, 'tcx> Refiner<'a, 'tcx> { } rustc::ty::TyKind::Param(param_ty) => { match self.param(*param_ty)?.kind { - rty::GenericParamDefKind::Type { .. } => { + rty::GenericParamDefKind::Type { .. } | rty::GenericParamDefKind::SplTy => { return Ok(rty::Binder::new(rty::Ty::param(*param_ty), List::empty())); } rty::GenericParamDefKind::BaseTy => rty::BaseTy::Param(*param_ty), diff --git a/crates/flux-middle/src/rustc/ty.rs b/crates/flux-middle/src/rustc/ty.rs index 9d04e36e9f..f47139e96d 100644 --- a/crates/flux-middle/src/rustc/ty.rs +++ b/crates/flux-middle/src/rustc/ty.rs @@ -473,7 +473,8 @@ impl TyKind { impl Ty { pub fn mk_adt(adt_def: AdtDef, args: impl Into) -> Ty { - TyKind::Adt(adt_def, args.into()).intern() + let args = args.into(); + TyKind::Adt(adt_def, args).intern() } pub fn mk_closure(def_id: DefId, args: impl Into) -> Ty { diff --git a/crates/flux-refineck/locales/en-US.ftl b/crates/flux-refineck/locales/en-US.ftl index 35f93c1cc1..4909a54da7 100644 --- a/crates/flux-refineck/locales/en-US.ftl +++ b/crates/flux-refineck/locales/en-US.ftl @@ -28,6 +28,9 @@ refineck_assert_error = refineck_param_inference_error = parameter inference error at function call +refineck_invalid_generic_arg = + cannot instantiate base or spl generic with opaque type + refineck_fold_error = type invariant may not hold (when place is folded) diff --git a/crates/flux-refineck/src/checker.rs b/crates/flux-refineck/src/checker.rs index 9d5ebcc2c1..e76a1d7300 100644 --- a/crates/flux-refineck/src/checker.rs +++ b/crates/flux-refineck/src/checker.rs @@ -1344,6 +1344,7 @@ pub(crate) mod errors { Inference, OpaqueStruct(DefId), Query(QueryErr), + InvalidGenericArg, } impl CheckerError { @@ -1365,6 +1366,12 @@ pub(crate) mod errors { flux_errors::diagnostic_id(), ) } + CheckerErrKind::InvalidGenericArg => { + handler.struct_err_with_code( + fluent::refineck_invalid_generic_arg, + flux_errors::diagnostic_id(), + ) + } CheckerErrKind::OpaqueStruct(def_id) => { let mut builder = handler.struct_err_with_code( fluent::refineck_opaque_struct_error, diff --git a/crates/flux-refineck/src/constraint_gen.rs b/crates/flux-refineck/src/constraint_gen.rs index 97ff093f60..205d831996 100644 --- a/crates/flux-refineck/src/constraint_gen.rs +++ b/crates/flux-refineck/src/constraint_gen.rs @@ -9,8 +9,9 @@ use flux_middle::{ evars::{EVarCxId, EVarSol, UnsolvedEvar}, fold::TypeFoldable, AliasTy, BaseTy, BinOp, Binder, Const, Constraint, ESpan, EVarGen, EarlyBinder, Expr, - ExprKind, FnOutput, GeneratorObligPredicate, GenericArg, GenericArgs, HoleKind, InferMode, - Mutability, Path, PolyFnSig, PolyVariant, PtrKind, Ref, Sort, TupleTree, Ty, TyKind, Var, + ExprKind, FnOutput, GeneratorObligPredicate, GenericArg, GenericArgs, GenericParamDefKind, + HoleKind, InferMode, Mutability, Path, PolyFnSig, PolyVariant, PtrKind, Ref, Sort, + TupleTree, Ty, TyKind, Var, }, rustc::mir::{BasicBlock, Place}, }; @@ -146,6 +147,26 @@ impl<'a, 'tcx> ConstrGen<'a, 'tcx> { Ok(res) } + fn check_generic_args( + &self, + did: DefId, + generic_args: &[GenericArg], + ) -> Result<(), CheckerErrKind> { + let generics = self.genv.generics_of(did)?; + for (idx, arg) in generic_args.iter().enumerate() { + let param = generics.param_at(idx, self.genv)?; + match param.kind { + GenericParamDefKind::BaseTy => { + if !arg.is_valid_base_arg() { + return Err(CheckerErrKind::InvalidGenericArg); + } + } + _ => continue, + } + } + Ok(()) + } + #[allow(clippy::too_many_arguments)] pub(crate) fn check_fn_call( &mut self, @@ -184,6 +205,10 @@ impl<'a, 'tcx> ConstrGen<'a, 'tcx> { let callsite_def_id = self.def_id; let span = self.span; + if let Some(did) = callee_def_id { + self.check_generic_args(did, generic_args)?; + } + let mut infcx = self.infcx(rcx, ConstrReason::Call); let snapshot = rcx.snapshot(); @@ -698,7 +723,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { } } (GenericArg::BaseTy(_), GenericArg::BaseTy(_)) => { - tracked_span_bug!("sgeneric argument subtyping for base types is not implemented"); + tracked_span_bug!("generic argument subtyping for base types is not implemented"); } (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => Ok(()), _ => tracked_span_bug!("incompatible generic args: `{arg1:?}` `{arg2:?}"), diff --git a/crates/flux-syntax/src/grammar.lalrpop b/crates/flux-syntax/src/grammar.lalrpop index 3a825585e2..f637d6fd2e 100644 --- a/crates/flux-syntax/src/grammar.lalrpop +++ b/crates/flux-syntax/src/grammar.lalrpop @@ -10,8 +10,8 @@ use lalrpop_util::ParseError; grammar(cx: &mut ParseCtxt<'_>); -Generics: surface::Generics = { - "<" > ">" => { +pub Generics: surface::Generics = { + > => { surface::Generics { params, span: cx.map_span(lo, hi), @@ -25,6 +25,7 @@ GenericParam: surface::GenericParam = { let kind = match kind.as_str() { "type" => surface::GenericParamKind::Type, "base" => surface::GenericParamKind::Base, + "spl" => surface::GenericParamKind::Spl, _ => return Err(ParseError::User { error: UserParseError::UnexpectedToken(lo, hi) }) }; Ok(surface::GenericParam { name, kind }) @@ -62,7 +63,7 @@ pub TyAlias: surface::TyAlias = { } pub RefinedBy: surface::RefinedBy = { - > => surface::RefinedBy { + > => surface::RefinedBy { index_params, early_bound_params: vec![], span: cx.map_span(lo, hi) @@ -92,14 +93,19 @@ Qualifier: surface::Qualifier = { } FuncDef: surface::FuncDef = { - "fn" "(" > ")" "->" "{" "}" => { - surface::FuncDef { name, args, output, body: Some(body) } + "fn" "(" > ")" "->" "{" "}" => { + surface::FuncDef { name, sort_vars: vars.unwrap_or_default(), args, output, body: Some(body) } }, - "fn" "(" > ")" "->" ";" => { - surface::FuncDef { name, args, output, body: None } + "fn" "(" > ")" "->" ";" => { + surface::FuncDef { name, sort_vars: vars.unwrap_or_default(), args, output, body: None } } } +SortVars: Vec = { + "<" > ">" => vars, +} + + SortDecl: surface::SortDecl = { "opaque" "sort" ";" => { surface::SortDecl { name } @@ -141,7 +147,7 @@ pub FnSig: surface::FnSig = { "fn" - + ")?> "(" ")" " )?> @@ -157,6 +163,7 @@ pub FnSig: surface::FnSig = { } else { surface::FnRetTy::Default(cx.map_span(ret_lo, ret_hi)) }; + let generics = generics.map(|z| z.1); surface::FnSig { asyncness, generics, diff --git a/crates/flux-syntax/src/lib.rs b/crates/flux-syntax/src/lib.rs index 13640fee8c..98805e3f1b 100644 --- a/crates/flux-syntax/src/lib.rs +++ b/crates/flux-syntax/src/lib.rs @@ -41,6 +41,14 @@ impl ParseSess { parse!(self, grammar::RefinedByParser, tokens, span) } + pub fn parse_generics( + &mut self, + tokens: &TokenStream, + span: Span, + ) -> ParseResult { + parse!(self, grammar::GenericsParser, tokens, span) + } + pub fn parse_type_alias( &mut self, tokens: &TokenStream, diff --git a/crates/flux-syntax/src/surface.rs b/crates/flux-syntax/src/surface.rs index d66196af49..9b1c440a80 100644 --- a/crates/flux-syntax/src/surface.rs +++ b/crates/flux-syntax/src/surface.rs @@ -38,6 +38,7 @@ pub struct Qualifier { #[derive(Debug)] pub struct FuncDef { pub name: Ident, + pub sort_vars: Vec, pub args: Vec, pub output: Sort, /// Body of the function. If not present this definition corresponds to an uninterpreted function. @@ -59,6 +60,7 @@ pub struct GenericParam { #[derive(Debug)] pub enum GenericParamKind { Type, + Spl, Base, Refine { sort: Sort }, } @@ -74,6 +76,7 @@ pub struct TyAlias { #[derive(Debug)] pub struct StructDef { + pub generics: Option, pub refined_by: Option, pub fields: Vec>, pub opaque: bool, @@ -247,6 +250,20 @@ pub enum TyKind { ImplTrait(NodeId, GenericBounds), } +impl Ty { + pub fn as_bty(&self) -> Option<&BaseTy> { + match &self.kind { + TyKind::Base(bty) | TyKind::Indexed { bty, .. } | TyKind::Exists { bty, .. } => { + Some(bty) + } + TyKind::GeneralExists { ty, .. } | TyKind::Constr(_, ty) => ty.as_bty(), + TyKind::Ref(_, _) + | TyKind::Tuple(_) + | TyKind::Array(_, _) + | TyKind::ImplTrait(_, _) => None, + } + } +} #[derive(Debug)] pub struct BaseTy { pub kind: BaseTyKind, diff --git a/crates/flux-tests/tests/lib/rmap.rs b/crates/flux-tests/tests/lib/rmap.rs index a8c429f586..8f74fbc6ef 100644 --- a/crates/flux-tests/tests/lib/rmap.rs +++ b/crates/flux-tests/tests/lib/rmap.rs @@ -1,32 +1,44 @@ #![allow(dead_code)] #![flux::defs { - fn map_set(m:Map, k: int, v: int) -> Map { map_store(m, k, v) } - fn map_get(m: Map, k:int) -> int { map_select(m, k) } - fn map_def(v:int) -> Map { map_default(v) } + fn map_set(m:Map, k: K, v: V) -> Map { map_store(m, k, v) } + fn map_get(m: Map, k:K) -> V { map_select(m, k) } + fn map_def(v:V) -> Map { map_default(v) } }] +use std::hash::Hash; + /// define a type indexed by a map #[flux::opaque] -#[flux::refined_by(map: Map)] -pub struct RMap { - inner: std::collections::HashMap, +#[flux::refined_by(vals: Map)] +pub struct RMap { + inner: std::collections::HashMap, } -impl RMap { +#[flux::generics(K as base, V as base)] +impl RMap { #[flux::trusted] + + /// #[flux::sig(fn() -> RMap{m: true})] "OK" i.e. wraps K, V in existential + /// #[flux::sig(fn() -> RMap{m: true})] "CRASH" i.e. wraps K, V in LAMBDA pub fn new() -> Self { Self { inner: std::collections::HashMap::new() } } #[flux::trusted] - #[flux::sig(fn(self: &strg RMap[@m], k: i32, v: i32) ensures self: RMap[map_set(m, k, v)])] - pub fn set(&mut self, k: i32, v: i32) { + #[flux::sig(fn(self: &strg RMap[@m], k: K, v: V) ensures self: RMap[map_set(m.vals, k, v)])] + pub fn set(&mut self, k: K, v: V) + where + K: Eq + Hash, + { self.inner.insert(k, v); } #[flux::trusted] - #[flux::sig(fn(&RMap[@m], k: i32) -> Option)] - pub fn get(&self, k: i32) -> Option { - self.inner.get(&k).copied() + #[flux::sig(fn(&RMap[@m], &K[@k]) -> Option<&V[map_get(m.vals, k)]>)] + pub fn get(&self, k: &K) -> Option<&V> + where + K: Eq + Hash, + { + self.inner.get(k) } } diff --git a/crates/flux-tests/tests/lib/rmapk.rs b/crates/flux-tests/tests/lib/rmapk.rs index c82b0a3036..0e142fa457 100644 --- a/crates/flux-tests/tests/lib/rmapk.rs +++ b/crates/flux-tests/tests/lib/rmapk.rs @@ -1,49 +1,64 @@ #![allow(dead_code)] #![flux::defs { - fn map_set(m:Map, k: int, v: int) -> Map { map_store(m, k, v) } - fn map_get(m: Map, k:int) -> int { map_select(m, k) } - fn map_def(v:int) -> Map { map_default(v) } - fn set_add(x: int, s: Set) -> Set { set_union(set_singleton(x), s) } - fn set_is_empty(s: Set) -> bool { s == set_empty(0) } - fn set_emp() -> Set { set_empty(0) } + fn map_set(m:Map, k: K, v: V) -> Map { map_store(m, k, v) } + fn map_get(m: Map, k:K) -> V { map_select(m, k) } + fn map_def(v:V) -> Map { map_default(v) } + fn set_add(x: T, s: Set) -> Set { set_union(set_singleton(x), s) } + fn set_is_empty(s: Set) -> bool { s == set_empty(0) } + fn set_emp() -> Set { set_empty(0) } }] +use std::hash::Hash; + /// define a type indexed by a map #[flux::opaque] -#[flux::refined_by(keys: Set, vals: Map)] -pub struct RMap { - inner: std::collections::HashMap, +#[flux::refined_by(keys: Set, vals: Map)] +pub struct RMap { + inner: std::collections::HashMap, } -impl RMap { +#[flux::generics(K as base, V as base)] +impl RMap { #[flux::trusted] - #[flux::sig(fn() -> RMap[set_empty(0), map_def(0)])] + #[flux::sig(fn() -> RMap{m: m.keys == set_empty(0)})] pub fn new() -> Self { Self { inner: std::collections::HashMap::new() } } #[flux::trusted] - #[flux::sig(fn(self: &strg RMap[@m], k: i32, v: i32) - ensures self: RMap[set_add(k, m.keys), map_set(m.vals, k, v)])] - pub fn set(&mut self, k: i32, v: i32) { + #[flux::sig(fn(self: &strg RMap[@m], k: K, v: V) + ensures self: RMap[set_add(k, m.keys), map_set(m.vals, k, v)])] + pub fn set(&mut self, k: K, v: V) + where + K: Eq + Hash, + { self.inner.insert(k, v); } #[flux::trusted] - #[flux::sig(fn(&RMap[@m], k: i32) -> Option)] - pub fn get(&self, k: i32) -> Option { - self.inner.get(&k).copied() + #[flux::sig(fn(&RMap[@m], &K[@k]) -> Option<&V[map_get(m.vals, k)]>)] + pub fn get(&self, k: &K) -> Option<&V> + where + K: Eq + Hash, + { + self.inner.get(k) } #[flux::trusted] - #[flux::sig(fn(&RMap[@m], k: i32) -> i32[map_get(m.vals, k)] requires set_is_in(k, m.keys))] - pub fn lookup(&self, k: i32) -> i32 { - *self.inner.get(&k).unwrap() + #[flux::sig(fn(&RMap[@m], &K[@k]) -> &V[map_get(m.vals, k)] requires set_is_in(k, m.keys))] + pub fn lookup(&self, k: &K) -> &V + where + K: Eq + Hash, + { + self.inner.get(k).unwrap() } #[flux::trusted] - #[flux::sig(fn(&RMap[@m], k: i32) -> bool[set_is_in(k, m.keys)])] - pub fn contains(&self, k: i32) -> bool { - self.inner.contains_key(&k) + #[flux::sig(fn(&RMap[@m], &K[@k]) -> bool[set_is_in(k, m.keys)])] + pub fn contains(&self, k: &K) -> bool + where + K: Eq + Hash, + { + self.inner.contains_key(k) } } diff --git a/crates/flux-tests/tests/lib/rset.rs b/crates/flux-tests/tests/lib/rset.rs new file mode 100644 index 0000000000..6895738128 --- /dev/null +++ b/crates/flux-tests/tests/lib/rset.rs @@ -0,0 +1,39 @@ +#![allow(dead_code)] + +use std::hash::Hash; + +#[flux::opaque] +#[flux::refined_by(elems: Set)] +pub struct RSet { + pub inner: std::collections::HashSet, +} + +// TODO: (RJ) I get some odd error with `T as spl` / cannot refine if I just remove this annotation! +// error: internal compiler error: crates/flux-middle/src/rty/subst.rs:353:30: expected base type for generic parameter, found `∃int. { i32[^0#0] | * }` +#[flux::generics(T as base)] +impl RSet { + #[flux::trusted] + #[flux::sig(fn() -> RSet[set_empty(0)])] + pub fn new() -> RSet { + let inner = std::collections::HashSet::new(); + RSet { inner } + } + + #[flux::trusted] + #[flux::sig(fn (set: &strg RSet[@s], elem: T) ensures set: RSet[set_union(set_singleton(elem), s)])] + pub fn insert(self: &mut Self, elem: T) + where + T: Eq + Hash, + { + self.inner.insert(elem); + } + + #[flux::trusted] + #[flux::sig(fn(set: &RSet[@s], &T[@elem]) -> bool[set_is_in(elem, s.elems)])] + pub fn contains(self: &Self, elem: &T) -> bool + where + T: Eq + Hash, + { + self.inner.contains(elem) + } +} diff --git a/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs b/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs new file mode 100644 index 0000000000..012a6e4b69 --- /dev/null +++ b/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs @@ -0,0 +1,15 @@ +#[path = "../../../lib/rset.rs"] +pub mod rset; + +use std::hash::Hash; + +use rset::RSet; + +pub fn test00_ok() -> Option { + Some(1) +} + +pub fn test00_bad() -> RSet { + //~^ ERROR values of this type cannot be used as base sorted instances + RSet::::new() +} diff --git a/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs b/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs new file mode 100644 index 0000000000..a4cb743b1c --- /dev/null +++ b/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs @@ -0,0 +1,17 @@ +#[path = "../../../lib/rset.rs"] +pub mod rset; + +use std::hash::Hash; + +use rset::RSet; + +// this is OK because we just dont generate an index for `soup` +#[flux::sig(fn(soup:RSet))] +pub fn test04(_s: RSet) {} + +#[flux::sig(fn(RSet[@salt]))] //~ ERROR type cannot be refined +pub fn test05(_s: RSet) +where + T: Eq + Hash, +{ +} diff --git a/crates/flux-tests/tests/neg/error_messages/wf/kinds02.rs b/crates/flux-tests/tests/neg/error_messages/wf/kinds02.rs new file mode 100644 index 0000000000..3ecc680ff6 --- /dev/null +++ b/crates/flux-tests/tests/neg/error_messages/wf/kinds02.rs @@ -0,0 +1,38 @@ +#[path = "../../../lib/rset.rs"] +pub mod rset; + +use std::hash::Hash; + +use rset::RSet; + +fn mk_eq_hash() -> impl Eq + Hash { + 0 +} + +#[flux::sig(fn(x:T) -> T[x])] +fn id(x: T) -> T { + x +} + +// This will try to call id with an `RSet` which can't be a "base" +pub fn test00() { + let z = mk_eq_hash(); + id(z); //~ ERROR cannot instantiate +} + +// This will try to create an `RSet` which can't be put into RSet +pub fn test01() { + let x = mk_eq_hash(); + let mut s = RSet::new(); //~ ERROR cannot instantiate + s.insert(x); +} + +// #[flux::sig(fn(x:T))] +// pub fn test01(x: T) { +// id(x); // TODO: REJECT-but-actually-ok +// } + +// fn test_bob(x: T) { +// let z = mk_eq_hash(); +// bob(z); // TODO: REJECT-but-actually-ok +// } diff --git a/crates/flux-tests/tests/neg/error_messages/wf/poly_sort.rs b/crates/flux-tests/tests/neg/error_messages/wf/poly_sort.rs new file mode 100644 index 0000000000..d9705df59c --- /dev/null +++ b/crates/flux-tests/tests/neg/error_messages/wf/poly_sort.rs @@ -0,0 +1,12 @@ +use std::hash::Hash; +#[flux::opaque] +#[flux::refined_by(elems: Set)] //~ ERROR cannot find sort `Tiger` +pub struct Foo { + pub inner: std::collections::HashSet, +} + +#[flux::opaque] +#[flux::refined_by(elems: Set)] //~ ERROR cannot find sort `Set` +pub struct Bar { + pub inner: std::collections::HashSet, +} diff --git a/crates/flux-tests/tests/neg/surface/maps00.rs b/crates/flux-tests/tests/neg/surface/maps00.rs index c705466366..b104432fc7 100644 --- a/crates/flux-tests/tests/neg/surface/maps00.rs +++ b/crates/flux-tests/tests/neg/surface/maps00.rs @@ -7,11 +7,13 @@ fn assert(_b: bool) {} pub fn test() { let mut m = RMap::new(); - m.set(10, 1); - m.set(20, 2); + let k0 = 10; + let k1 = 20; + let k2 = 30; - assert(1 + 1 == 2); - assert(m.get(10).unwrap() == 1); - assert(m.get(20).unwrap() == 2); - assert(m.get(30).unwrap() == 3); //~ ERROR refinement type + m.set(k0, 1); + m.set(k1, 2); + assert(*m.get(&k0).unwrap() == 1); + assert(*m.get(&k1).unwrap() == 2); + assert(*m.get(&k2).unwrap() == 3); //~ ERROR refinement type } diff --git a/crates/flux-tests/tests/neg/surface/maps01.rs b/crates/flux-tests/tests/neg/surface/maps01.rs index 2f73ef9944..b5cd006f45 100644 --- a/crates/flux-tests/tests/neg/surface/maps01.rs +++ b/crates/flux-tests/tests/neg/surface/maps01.rs @@ -7,13 +7,16 @@ fn assert(_b: bool) {} pub fn test() { let mut m = RMap::new(); - m.set(10, 1); - m.set(20, 2); + let k0 = 10; + let k1 = 20; + let k2 = 30; - assert(1 + 1 == 2); - assert(m.get(20).unwrap() == 2); - assert(m.lookup(10) == 1); - assert(m.lookup(20) == 2); - assert(m.contains(10)); - assert(m.contains(30)); //~ ERROR refinement type + m.set(k0, 1); + m.set(k1, 2); + + assert(*m.get(&k1).unwrap() == 2); + assert(*m.lookup(&k0) == 1); + assert(*m.lookup(&k1) == 2); + assert(m.contains(&k0)); + assert(m.contains(&k2)); //~ ERROR refinement type } diff --git a/crates/flux-tests/tests/neg/surface/rset00.rs b/crates/flux-tests/tests/neg/surface/rset00.rs new file mode 100644 index 0000000000..669cbee38a --- /dev/null +++ b/crates/flux-tests/tests/neg/surface/rset00.rs @@ -0,0 +1,49 @@ +use std::hash::Hash; +#[flux::opaque] +#[flux::refined_by(elems: Set)] +pub struct RSet { + pub inner: std::collections::HashSet, +} + +#[flux::trusted] +#[flux::sig(fn(s: RSet) -> RSet[s.elems])] +pub fn id(s: RSet) -> RSet { + s +} + +#[flux::sig(fn (bool[true]))] +fn assert(_b: bool) {} + +#[flux::trusted] +#[flux::sig(fn() -> RSet[set_empty(0)])] +pub fn empty() -> RSet { + let inner = std::collections::HashSet::new(); + RSet { inner } +} + +#[flux::trusted] +#[flux::sig(fn(set: &strg RSet[@s], elem: T) ensures set: RSet[set_union(set_singleton(elem), s)])] +pub fn insert(set: &mut RSet, elem: T) +where + T: Eq + Hash, +{ + set.inner.insert(elem); +} + +#[flux::trusted] +#[flux::sig(fn(set: &RSet[@s], &A[@elem]) -> bool[set_is_in(elem, s.elems)])] +pub fn contains(set: &RSet, elem: &A) -> bool +where + A: Eq + Hash, +{ + set.inner.contains(elem) +} + +pub fn test() { + let mut s = empty(); + let v0 = 666; + let v1 = 667; + insert(&mut s, v0); + assert(contains(&s, &v0)); + assert(contains(&s, &v1)); //~ ERROR refinement type +} diff --git a/crates/flux-tests/tests/neg/surface/rset01.rs b/crates/flux-tests/tests/neg/surface/rset01.rs new file mode 100644 index 0000000000..f0e219751f --- /dev/null +++ b/crates/flux-tests/tests/neg/surface/rset01.rs @@ -0,0 +1,25 @@ +#[path = "../../lib/rset.rs"] +mod rset; +use rset::RSet; + +#[flux::sig(fn (bool[true]))] +fn assert(_b: bool) {} + +pub fn test() { + let mut s = RSet::new(); + let v0 = 666; + let v1 = 667; + s.insert(v0); + assert(s.contains(&v0)); + assert(s.contains(&v1)); //~ ERROR refinement type +} + +#[flux::sig(fn () -> RSet)] +pub fn test1() -> RSet { + let mut s = RSet::new(); + let v0 = 666; + let v1 = 667; + s.insert(v0); + s.insert(v1); + s //~ ERROR refinement type +} diff --git a/crates/flux-tests/tests/pos/surface/rset00.rs b/crates/flux-tests/tests/pos/surface/rset00.rs new file mode 100644 index 0000000000..9a5f5f89de --- /dev/null +++ b/crates/flux-tests/tests/pos/surface/rset00.rs @@ -0,0 +1,69 @@ +use std::hash::Hash; +#[flux::opaque] +#[flux::refined_by( elems: Set )] +pub struct RSet { + pub inner: std::collections::HashSet, +} + +#[flux::trusted] +#[flux::sig(fn(s: RSet) -> RSet[s.elems])] +pub fn id(s: RSet) -> RSet { + s +} + +#[flux::sig(fn (bool[true]))] +fn assert(_b: bool) {} + +#[flux::trusted] +#[flux::sig(fn() -> RSet[set_empty(0)])] +pub fn empty() -> RSet { + let inner = std::collections::HashSet::new(); + RSet { inner } +} + +#[flux::trusted] +#[flux::sig(fn(set: &strg RSet[@s], elem: T) ensures set: RSet[set_union(set_singleton(elem), s)])] +pub fn insert(set: &mut RSet, elem: T) +where + T: Eq + Hash, +{ + set.inner.insert(elem); +} + +#[flux::trusted] +#[flux::sig(fn(set: &RSet[@soup], &A[@elem]) -> bool[set_is_in(elem, soup.elems)])] +pub fn contains(set: &RSet, elem: &A) -> bool +where + A: Eq + Hash, +{ + set.inner.contains(elem) +} + +pub fn test() { + let mut s = empty(); + let v0 = 666; + let v1 = 667; + insert(&mut s, v0); + assert(contains(&s, &v0)); + assert(!contains(&s, &v1)); +} + +// i32[10] + +#[flux::sig(fn(RSet 0}>[@s], y:i32{set_is_in(y, s.elems)}) )] +pub fn test2(s: RSet, y: i32) { + assert(contains(&s, &y)); +} + +#[flux::sig(fn(RSet[@s], y:T{set_is_in(y, s.elems)}))] +pub fn test3(s: RSet, y: T) +where + T: Eq + Hash, +{ + assert(contains(&s, &y)); +} + +#[flux::sig(fn(RSet[@s], y:i32{0 <= y && set_is_in(y, s.elems)}) )] +pub fn test4(s: RSet, y: i32) { + test3(s, y) +} diff --git a/crates/flux-tests/tests/pos/surface/rset01.rs b/crates/flux-tests/tests/pos/surface/rset01.rs new file mode 100644 index 0000000000..983f835fd8 --- /dev/null +++ b/crates/flux-tests/tests/pos/surface/rset01.rs @@ -0,0 +1,25 @@ +#[path = "../../lib/rset.rs"] +mod rset; +use rset::RSet; + +#[flux::sig(fn (bool[true]))] +fn assert(_b: bool) {} + +pub fn test() { + let mut s = RSet::new(); + let v0 = 666; + let v1 = 667; + s.insert(v0); + assert(s.contains(&v0)); + assert(!s.contains(&v1)); +} + +#[flux::sig(fn () -> RSet)] +pub fn test1() -> RSet { + let mut s = RSet::new(); + let v0 = 666; + let v1 = 667; + s.insert(v0); + s.insert(v1); + s +} diff --git a/crates/flux-tests/tests/todo/rset.rs b/crates/flux-tests/tests/todo/rset.rs deleted file mode 100644 index 347ce08900..0000000000 --- a/crates/flux-tests/tests/todo/rset.rs +++ /dev/null @@ -1,41 +0,0 @@ -#![allow(dead_code)] -#![flux::defs { - fn set_add(x: int, s: Set) -> Set { set_union(set_singleton(x), s) } -}] - -#[flux::opaque] -#[flux::refined_by(elems: Set)] -pub struct RSet { - inner: std::collections::HashSet, -} - -impl RSet { - #[flux::trusted] - #[flux::sig(fn() -> RSet[set_empty(0)])] - pub fn new() -> Self { - Self { inner: std::collections::HashSet::new() } - } - - #[flux::trusted] - #[flux::sig(fn(self: &strg RSet[@s], elem: T) - ensures self: RSet[set_add(k, s.elems)])] - pub fn insert(&mut self, elem: T) { - self.inner.insert(elem); - } - - #[flux::trusted] - #[flux::sig(fn(&Set[@s], &T[@elem]) -> bool[set_is_in(elem, s.elems)])] - pub fn contains(&self, elem: &T) -> bool { - self.inner.contains(elem) - } -} - -#[flux::sig(fn (bool[true]))] -fn assert(_b: bool) {} - -fn test() { - let mut s = RSet::new(); - s.insert(1); - assert(s.contains(1)); - assert(!s.contains(2)); -}