diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b906ddd47..801e137684 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,7 +31,7 @@ jobs: with: enable-stack: true stack-version: "latest" - + - name: Install Z3 uses: cda-tum/setup-z3@v1.0.9 with: diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 066e12d536..b543518d4a 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -3,10 +3,8 @@ name: gh-pages on: push: branches: [main] - paths: book/** pull_request: branches: [main] - paths: book/** jobs: build: diff --git a/crates/flux-desugar/src/desugar.rs b/crates/flux-desugar/src/desugar.rs index ff4b95f6e8..0372e83a55 100644 --- a/crates/flux-desugar/src/desugar.rs +++ b/crates/flux-desugar/src/desugar.rs @@ -212,8 +212,7 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> { surface::GenericParamKind::Type => { fhir::GenericParamKind::Type { default: None } } - surface::GenericParamKind::Spl => fhir::GenericParamKind::SplTy, - surface::GenericParamKind::Base => fhir::GenericParamKind::BaseTy, + surface::GenericParamKind::Base => fhir::GenericParamKind::Base, surface::GenericParamKind::Refine { .. } => unreachable!(), }; self_kind = Some(kind); @@ -230,8 +229,7 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> { .transpose()?, } } - surface::GenericParamKind::Base => fhir::GenericParamKind::BaseTy, - surface::GenericParamKind::Spl => fhir::GenericParamKind::SplTy, + surface::GenericParamKind::Base => fhir::GenericParamKind::Base, surface::GenericParamKind::Refine { .. } => unreachable!(), }; surface_params.insert(def_id, fhir::GenericParam { def_id, kind }); diff --git a/crates/flux-fhir-analysis/locales/en-US.ftl b/crates/flux-fhir-analysis/locales/en-US.ftl index b647bd688c..86567b4760 100644 --- a/crates/flux-fhir-analysis/locales/en-US.ftl +++ b/crates/flux-fhir-analysis/locales/en-US.ftl @@ -80,9 +80,6 @@ fhir_analysis_expected_numeric = fhir_analysis_no_equality = values of sort `{$sort}` cannot be compared for equality -fhir_analysis_invalid_base_instance = - values of this type cannot be used as base sorted instances - fhir_analysis_param_not_determined = parameter `{$name}` cannot be determined .label = undetermined parameter @@ -166,6 +163,9 @@ fhir_analysis_assoc_type_not_found = .label = cannot resolve this associated type .note = Flux cannot resolved associated types if they are defined in a super trait +fhir_analysis_invalid_base_instance = + values of this type cannot be used as base sorted instances + # Check impl against trait errors fhir_analysis_incompatible_sort = diff --git a/crates/flux-fhir-analysis/src/conv.rs b/crates/flux-fhir-analysis/src/conv.rs index ba0014fefb..46e0c2781f 100644 --- a/crates/flux-fhir-analysis/src/conv.rs +++ b/crates/flux-fhir-analysis/src/conv.rs @@ -219,8 +219,7 @@ fn conv_generic_param_kind(kind: &fhir::GenericParamKind) -> rty::GenericParamDe fhir::GenericParamKind::Type { default } => { rty::GenericParamDefKind::Type { has_default: default.is_some() } } - fhir::GenericParamKind::SplTy => rty::GenericParamDefKind::SplTy, - fhir::GenericParamKind::BaseTy => rty::GenericParamDefKind::BaseTy, + fhir::GenericParamKind::Base => rty::GenericParamDefKind::Base, fhir::GenericParamKind::Lifetime => rty::GenericParamDefKind::Lifetime, } } @@ -426,7 +425,7 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { clauses: &mut Vec, ) -> QueryResult { let mut into = vec![rty::GenericArg::Ty(bounded_ty.clone())]; - self.conv_generic_args_into(env, args, &mut into)?; + self.conv_generic_args_into(env, trait_id, args, &mut into)?; self.fill_generic_args_defaults(trait_id, &mut into)?; let trait_ref = rty::TraitRef { def_id: trait_id, args: into.into() }; let pred = rty::TraitPredicate { trait_ref }; @@ -453,9 +452,9 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { let kind = rty::BoundRegionKind::BrNamed(def_id.to_def_id(), name); Ok(rty::BoundVariableKind::Region(kind)) } - fhir::GenericParamKind::Type { default: _ } - | fhir::GenericParamKind::BaseTy - | fhir::GenericParamKind::SplTy => bug!("unexpected!"), + fhir::GenericParamKind::Type { .. } | fhir::GenericParamKind::Base => { + bug!("unexpected!") + } } } @@ -673,7 +672,15 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { fhir::TyKind::BaseTy(bty) => self.conv_base_ty(env, bty), fhir::TyKind::Indexed(bty, idx) => { let idx = self.conv_refine_arg(env, idx)?; - self.conv_indexed_type(env, bty, idx) + match &bty.kind { + fhir::BaseTyKind::Path(fhir::QPath::Resolved(_, path)) => { + Ok(self.conv_ty_ctor(env, path)?.replace_bound_reft(&idx)) + } + fhir::BaseTyKind::Slice(ty) => { + let bty = rty::BaseTy::Slice(self.conv_ty(env, ty)?); + Ok(rty::Ty::indexed(bty, idx)) + } + } } fhir::TyKind::Exists(params, ty) => { let layer = Layer::list(self, 0, params, false); @@ -738,50 +745,54 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { } fn conv_base_ty(&self, env: &mut Env, bty: &fhir::BaseTy) -> QueryResult { - let sort = self.genv.sort_of_bty(bty); - - if let fhir::BaseTyKind::Path(fhir::QPath::Resolved(self_ty, path)) = &bty.kind { - if let fhir::Res::Def(DefKind::AssocTy, def_id) = path.res { - let self_ty = self.conv_ty(env, self_ty.as_deref().unwrap())?; - let [.., trait_segment, assoc_segment] = path.segments else { - span_bug!(bty.span, "expected at least two segments"); - }; - let mut args = vec![rty::GenericArg::Ty(self_ty)]; - self.conv_generic_args_into(env, trait_segment.args, &mut args)?; - self.conv_generic_args_into(env, assoc_segment.args, &mut args)?; - let args = List::from_vec(args); - - let refine_args = List::empty(); - let alias_ty = rty::AliasTy { args, refine_args, def_id }; - return Ok(rty::Ty::alias(rty::AliasKind::Projection, alias_ty)); - } - // If it is a type parameter with no sort, it means it is of kind `Type` - if let fhir::Res::SelfTyParam { .. } = path.res - && sort.is_none() - { - return Ok(rty::Ty::param(rty::SELF_PARAM_TY)); + match &bty.kind { + fhir::BaseTyKind::Path(fhir::QPath::Resolved(self_ty, path)) => { + match path.res { + fhir::Res::Def(DefKind::AssocTy, assoc_id) => { + let trait_id = self.genv.tcx().trait_of_item(assoc_id).unwrap(); + let self_ty = self.conv_ty(env, self_ty.as_deref().unwrap())?; + let [.., trait_segment, assoc_segment] = path.segments else { + span_bug!(bty.span, "expected at least two segments"); + }; + let mut args = vec![rty::GenericArg::Ty(self_ty)]; + self.conv_generic_args_into(env, trait_id, trait_segment.args, &mut args)?; + self.conv_generic_args_into(env, assoc_id, assoc_segment.args, &mut args)?; + let args = List::from_vec(args); + + let refine_args = List::empty(); + let alias_ty = rty::AliasTy { args, refine_args, def_id: assoc_id }; + return Ok(rty::Ty::alias(rty::AliasKind::Projection, alias_ty)); + } + fhir::Res::SelfTyParam { trait_ } => { + let param = self.genv.generics_of(trait_)?.param_at(0, self.genv)?; + if let rty::GenericParamDefKind::Type { .. } = param.kind { + return Ok(rty::Ty::param(rty::SELF_PARAM_TY)); + } + } + fhir::Res::Def(DefKind::TyParam, def_id) => { + let owner_id = self.genv.hir().ty_param_owner(def_id.expect_local()); + let param_ty = self.genv.def_id_to_param_ty(def_id.expect_local()); + let param = self + .genv + .generics_of(owner_id)? + .param_at(param_ty.index as usize, self.genv)?; + if let rty::GenericParamDefKind::Type { .. } = param.kind { + return Ok(rty::Ty::param(param_ty)); + } + } + _ => {} + } + Ok(self.conv_ty_ctor(env, path)?.to_ty()) } - if let fhir::Res::Def(DefKind::TyParam, def_id) = path.res - && sort.is_none() - { - let param_ty = self.genv.def_id_to_param_ty(def_id.expect_local()); - return Ok(rty::Ty::param(param_ty)); + fhir::BaseTyKind::Slice(ty) => { + let bty = rty::BaseTy::Slice(self.conv_ty(env, ty)?).shift_in_escaping(1); + let sort = bty.sort(); + Ok(rty::Ty::exists(rty::Binder::with_sort( + rty::Ty::indexed(bty, rty::Expr::nu()), + sort, + ))) } } - let sort = sort.unwrap(); - if sort.is_unit() { - let idx = rty::Expr::unit(); - self.conv_indexed_type(env, bty, idx) - } else if let Some(def_id) = sort.is_unit_adt() { - let idx = rty::Expr::unit_adt(def_id); - self.conv_indexed_type(env, bty, idx) - } else { - env.push_layer(Layer::empty()); - let idx = rty::Expr::nu(); - let ty = self.conv_indexed_type(env, bty, idx)?; - env.pop_layer(); - Ok(rty::Ty::exists(rty::Binder::with_sort(ty, sort))) - } } fn conv_lifetime(&self, env: &Env, lft: fhir::Lifetime) -> rty::Region { @@ -850,29 +861,7 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { } } - fn conv_indexed_type( - &self, - env: &mut Env, - bty: &fhir::BaseTy, - idx: rty::Expr, - ) -> QueryResult { - match &bty.kind { - fhir::BaseTyKind::Path(fhir::QPath::Resolved(_, path)) => { - self.conv_indexed_path(env, path, idx) - } - fhir::BaseTyKind::Slice(ty) => { - let slice = rty::BaseTy::slice(self.conv_ty(env, ty)?); - Ok(rty::Ty::indexed(slice, idx)) - } - } - } - - fn conv_indexed_path( - &self, - env: &mut Env, - path: &fhir::Path, - idx: rty::Expr, - ) -> QueryResult { + fn conv_ty_ctor(&self, env: &mut Env, path: &fhir::Path) -> QueryResult { let bty = match &path.res { fhir::Res::PrimTy(PrimTy::Bool) => rty::BaseTy::Bool, fhir::Res::PrimTy(PrimTy::Str) => rty::BaseTy::Str, @@ -896,11 +885,7 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { } fhir::Res::SelfTyParam { .. } => rty::BaseTy::Param(rty::SELF_PARAM_TY), fhir::Res::SelfTyAlias { alias_to, .. } => { - return Ok(self - .genv - .type_of(*alias_to)? - .instantiate_identity(&[]) - .replace_bound_reft(&idx)); + return Ok(self.genv.type_of(*alias_to)?.instantiate_identity(&[])); } fhir::Res::Def(DefKind::TyAlias { .. }, def_id) => { let generics = self.conv_generic_args(env, *def_id, path.last_segment().args)?; @@ -909,17 +894,15 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { .iter() .map(|arg| self.conv_refine_arg(env, arg)) .try_collect_vec()?; - return Ok(self - .genv - .type_of(*def_id)? - .instantiate(&generics, &refine) - .replace_bound_reft(&idx)); + return Ok(self.genv.type_of(*def_id)?.instantiate(&generics, &refine)); } fhir::Res::Def(..) => { span_bug!(path.span, "unexpected resolution in conv_indexed_path: {:?}", path.res) } }; - Ok(rty::Ty::indexed(bty, idx)) + let sort = bty.sort(); + let bty = bty.shift_in_escaping(1); + Ok(rty::Binder::with_sort(rty::Ty::indexed(bty, rty::Expr::nu()), sort)) } pub fn conv_generic_args( @@ -929,7 +912,7 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { args: &[fhir::GenericArg], ) -> QueryResult> { let mut into = vec![]; - self.conv_generic_args_into(env, args, &mut into)?; + self.conv_generic_args_into(env, def_id, args, &mut into)?; self.fill_generic_args_defaults(def_id, &mut into)?; Ok(into) } @@ -937,17 +920,35 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { fn conv_generic_args_into( &self, env: &mut Env, + def_id: DefId, args: &[fhir::GenericArg], into: &mut Vec, - ) -> QueryResult<()> { - for arg in args { - match arg { - fhir::GenericArg::Lifetime(lft) => { + ) -> QueryResult { + let generics = self.genv.generics_of(def_id)?; + for (idx, arg) in args.iter().enumerate() { + let param = generics.param_at(idx, self.genv)?; + match (arg, ¶m.kind) { + (fhir::GenericArg::Lifetime(lft), rty::GenericParamDefKind::Lifetime) => { into.push(rty::GenericArg::Lifetime(self.conv_lifetime(env, *lft))); } - fhir::GenericArg::Type(ty) => { + (fhir::GenericArg::Type(ty), rty::GenericParamDefKind::Type { .. }) => { into.push(rty::GenericArg::Ty(self.conv_ty(env, ty)?)); } + (fhir::GenericArg::Type(ty), rty::GenericParamDefKind::Base) => { + let ctor = self + .conv_ty(env, ty)? + .shallow_canonicalize() + .to_subset_ty_ctor() + .ok_or_else(|| { + self.genv + .sess() + .emit_err(errors::InvalidBaseInstance::new(ty)) + })?; + into.push(rty::GenericArg::Base(ctor)); + } + _ => { + bug!("unexpected param `{:?}` for arg `{arg:?}`", param.kind); + } } } Ok(()) @@ -957,7 +958,7 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { &self, def_id: DefId, into: &mut Vec, - ) -> QueryResult<()> { + ) -> QueryResult { let generics = self.genv.generics_of(def_id)?; for param in generics.params.iter().skip(into.len()) { if let rty::GenericParamDefKind::Type { has_default } = param.kind { @@ -966,7 +967,7 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> { .genv .type_of(param.def_id)? .instantiate(into, &[]) - .into_ty(); + .to_ty(); into.push(rty::GenericArg::Ty(ty)); } else { bug!("unexpected generic param: {param:?}"); @@ -1203,10 +1204,6 @@ impl Layer { Self::new(cx, 0, params, false, LayerKind::Record(def_id)) } - fn empty() -> Self { - Self { map: FxIndexMap::default(), filter_unit: false, kind: LayerKind::List } - } - fn get(&self, name: impl Borrow, level: u32) -> Option { Some(LookupResultKind::LateBound { level, @@ -1478,7 +1475,7 @@ fn conv_un_op(op: fhir::UnOp) -> rty::UnOp { mod errors { use flux_macros::Diagnostic; - use flux_middle::fhir::SurfaceIdent; + use flux_middle::fhir::{self, SurfaceIdent}; use rustc_span::Span; #[derive(Diagnostic)] @@ -1495,4 +1492,18 @@ mod errors { Self { span: ident.span } } } + + #[derive(Diagnostic)] + #[diag(fhir_analysis_invalid_base_instance, code = "FLUX")] + pub(super) struct InvalidBaseInstance<'fhir> { + #[primary_span] + span: Span, + ty: &'fhir fhir::Ty<'fhir>, + } + + impl<'fhir> InvalidBaseInstance<'fhir> { + pub(super) fn new(ty: &'fhir fhir::Ty<'fhir>) -> Self { + Self { ty, span: ty.span } + } + } } diff --git a/crates/flux-fhir-analysis/src/lib.rs b/crates/flux-fhir-analysis/src/lib.rs index 38044cb22b..ebe95a4e32 100644 --- a/crates/flux-fhir-analysis/src/lib.rs +++ b/crates/flux-fhir-analysis/src/lib.rs @@ -291,7 +291,7 @@ fn refinement_generics_of( } } -fn type_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult> { +fn type_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult> { let ty = match genv.def_kind(def_id) { DefKind::TyAlias { .. } => { let alias = genv.map().expect_item(def_id).expect_type_alias(); @@ -310,7 +310,7 @@ fn type_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult { let generics = genv.generics_of(def_id)?; let ty = genv.lower_type_of(def_id)?.skip_binder(); - Refiner::default(genv, &generics).refine_poly_ty(&ty)? + Refiner::default(genv, &generics).refine_ty_ctor(&ty)? } kind => { bug!("`{:?}` not supported", kind.descr(def_id.to_def_id())) diff --git a/crates/flux-fhir-analysis/src/wf/errors.rs b/crates/flux-fhir-analysis/src/wf/errors.rs index e94639f48a..59adc3ee9d 100644 --- a/crates/flux-fhir-analysis/src/wf/errors.rs +++ b/crates/flux-fhir-analysis/src/wf/errors.rs @@ -184,20 +184,6 @@ impl<'a> InvalidPrimitiveDotAccess<'a> { } } -#[derive(Diagnostic)] -#[diag(fhir_analysis_invalid_base_instance, code = "FLUX")] -pub(super) struct InvalidBaseInstance<'fhir> { - #[primary_span] - span: Span, - ty: fhir::Ty<'fhir>, -} - -impl<'fhir> InvalidBaseInstance<'fhir> { - pub(super) fn new(ty: fhir::Ty<'fhir>) -> Self { - Self { ty, span: ty.span } - } -} - #[derive(Diagnostic)] #[diag(fhir_analysis_param_not_determined, code = "FLUX")] #[help] diff --git a/crates/flux-fhir-analysis/src/wf/mod.rs b/crates/flux-fhir-analysis/src/wf/mod.rs index 2ecc3bc698..b2bad18b56 100644 --- a/crates/flux-fhir-analysis/src/wf/mod.rs +++ b/crates/flux-fhir-analysis/src/wf/mod.rs @@ -12,7 +12,7 @@ use flux_errors::{FluxSession, ResultExt}; use flux_middle::{ fhir::{self, ExprRes, FluxOwnerId, SurfaceIdent}, global_env::GlobalEnv, - rty::{self, GenericParamDefKind, WfckResults}, + rty::{self, WfckResults}, }; use rustc_data_structures::{ snapshot_map::{self, SnapshotMap}, @@ -20,7 +20,7 @@ use rustc_data_structures::{ }; use rustc_errors::{ErrorGuaranteed, IntoDiagnostic}; use rustc_hash::FxHashSet; -use rustc_hir::{def::DefKind, def_id::DefId, OwnerId}; +use rustc_hir::{def::DefKind, OwnerId}; use rustc_span::Symbol; use self::sortck::InferCtxt; @@ -37,7 +37,7 @@ struct Wf<'genv, 'tcx> { /// determined. The context is called Xi because in the paper [Focusing on Liquid Refinement Typing], /// the well-formedness judgment uses an uppercase Xi (Ξ) for a context that is similar in purpose. /// -/// This is basically a set of [`fhir::Name`] implemented with a snapshot map such that elements +/// This is basically a set of [`fhir::ParamId`] implemented with a snapshot map such that elements /// can be removed in batch when there's a change in polarity. /// /// [Focusing on Liquid Refinement Typing]: https://arxiv.org/pdf/2209.13000.pdf @@ -255,9 +255,8 @@ impl<'genv, 'tcx> Wf<'genv, 'tcx> { fn check_generic_bound(&mut self, infcx: &mut InferCtxt, bound: &fhir::GenericBound) -> Result { match bound { fhir::GenericBound::Trait(trait_ref, _) => self.check_path(infcx, &trait_ref.trait_ref), - fhir::GenericBound::LangItemTrait(lang_item, args, bindings) => { - let def_id = self.genv.tcx().require_lang_item(*lang_item, None); - self.check_generic_args(infcx, def_id, args)?; + fhir::GenericBound::LangItemTrait(_, args, bindings) => { + self.check_generic_args(infcx, args)?; self.check_type_bindings(infcx, bindings)?; Ok(()) } @@ -385,10 +384,9 @@ impl<'genv, 'tcx> Wf<'genv, 'tcx> { self.check_type(infcx, ty)?; self.check_expr_as_pred(infcx, pred) } - fhir::TyKind::OpaqueDef(item_id, args, _refine_args, _) => { + fhir::TyKind::OpaqueDef(_item_id, args, _refine_args, _) => { // TODO sanity check the _refine_args (though they should never fail!) but we'd need their expected sorts - let def_id = item_id.owner_id.to_def_id(); - self.check_generic_args(infcx, def_id, args) + self.check_generic_args(infcx, args) } fhir::TyKind::RawPtr(ty, _) => self.check_type(infcx, ty), fhir::TyKind::Hole(_) | fhir::TyKind::Never => Ok(()), @@ -408,26 +406,7 @@ impl<'genv, 'tcx> Wf<'genv, 'tcx> { .try_collect_exhaust() } - fn check_generic_args_kinds(&self, def_id: DefId, args: &[fhir::GenericArg]) -> Result { - let generics = self.genv.generics_of(def_id).emit(self.genv.sess())?; - for (arg, param) in iter::zip(args, &generics.params) { - if param.kind == GenericParamDefKind::SplTy { - let ty = arg.expect_type(); - if self.genv.sort_of_ty(ty).is_none() { - return self.emit_err(errors::InvalidBaseInstance::new(*ty)); - } - } - } - Ok(()) - } - - fn check_generic_args( - &mut self, - infcx: &mut InferCtxt, - def_id: DefId, - args: &[fhir::GenericArg], - ) -> Result { - self.check_generic_args_kinds(def_id, args)?; + fn check_generic_args(&mut self, infcx: &mut InferCtxt, args: &[fhir::GenericArg]) -> Result { args.iter() .try_for_each_exhaust(|arg| self.check_generic_arg(infcx, arg)) } @@ -490,10 +469,8 @@ impl<'genv, 'tcx> Wf<'genv, 'tcx> { // TODO(nilehmann) we should check all segments let last_segment = path.last_segment(); - if let fhir::Res::Def(_kind, did) = &path.res - && !last_segment.args.is_empty() - { - self.check_generic_args(infcx, *did, last_segment.args)?; + if !last_segment.args.is_empty() { + self.check_generic_args(infcx, last_segment.args)?; } let bindings = self.check_type_bindings(infcx, last_segment.bindings); if !self.genv.is_box(path.res) { diff --git a/crates/flux-fhir-analysis/src/wf/sortck.rs b/crates/flux-fhir-analysis/src/wf/sortck.rs index ccc76e40d1..10ce73215a 100644 --- a/crates/flux-fhir-analysis/src/wf/sortck.rs +++ b/crates/flux-fhir-analysis/src/wf/sortck.rs @@ -352,9 +352,9 @@ impl<'genv> InferCtxt<'genv, '_> { } /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an - /// [`rty::Sort::Adt`] sort with a single field of sort `s` can be coerced to a value of sort `s` - /// and vice versa, i.e., we can automatically project the field out of the record or inject a - /// value into a record. + /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort + /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject + /// a value into a record. fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool { if self.try_equate(sort1, sort2).is_some() { return true; diff --git a/crates/flux-metadata/src/lib.rs b/crates/flux-metadata/src/lib.rs index 101c0962a1..2a8dcf9095 100644 --- a/crates/flux-metadata/src/lib.rs +++ b/crates/flux-metadata/src/lib.rs @@ -49,7 +49,7 @@ pub struct CrateMetadata { fn_sigs: FxHashMap>, adts: FxHashMap, /// For now it only store type of aliases - type_of: FxHashMap>, + type_of: FxHashMap>, } #[derive(TyEncodable, TyDecodable)] @@ -98,7 +98,7 @@ impl CrateStore for CStore { .map(|adt| adt.variants.as_ref().map(rty::EarlyBinder::as_deref)) } - fn type_of(&self, def_id: DefId) -> Option<&rty::EarlyBinder> { + fn type_of(&self, def_id: DefId) -> Option<&rty::EarlyBinder> { self.meta.get(&def_id.krate)?.type_of.get(&def_id.index) } } diff --git a/crates/flux-middle/locales/en-US.ftl b/crates/flux-middle/locales/en-US.ftl index d17d189902..1cd8fabcf2 100644 --- a/crates/flux-middle/locales/en-US.ftl +++ b/crates/flux-middle/locales/en-US.ftl @@ -24,3 +24,6 @@ middle_unsupported_generic_bound = middle_query_unsupported = unsupported signature .note = {$reason} + +middle_query_invalid_generic_arg = + cannot instantiate base generic with opaque type or a type parameter of kind type diff --git a/crates/flux-middle/src/cstore.rs b/crates/flux-middle/src/cstore.rs index 5368d5d3f5..93b5d96c94 100644 --- a/crates/flux-middle/src/cstore.rs +++ b/crates/flux-middle/src/cstore.rs @@ -9,7 +9,7 @@ pub trait CrateStore { &self, def_id: DefId, ) -> Option>>; - fn type_of(&self, def_id: DefId) -> Option<&rty::EarlyBinder>; + fn type_of(&self, def_id: DefId) -> Option<&rty::EarlyBinder>; } pub type CrateStoreDyn = dyn CrateStore; diff --git a/crates/flux-middle/src/fhir.rs b/crates/flux-middle/src/fhir.rs index 9db8d042b3..26b2bdfae3 100644 --- a/crates/flux-middle/src/fhir.rs +++ b/crates/flux-middle/src/fhir.rs @@ -62,8 +62,7 @@ pub struct GenericParam<'fhir> { #[derive(Debug, Clone, Copy)] pub enum GenericParamKind<'fhir> { Type { default: Option> }, - SplTy, - BaseTy, + Base, Lifetime, } @@ -475,12 +474,11 @@ pub enum TyKind<'fhir> { /// A type that parses as a [`BaseTy`] but was written without refinements. Most types in /// this category are base types and will be converted into an [existential], e.g., `i32` is /// converted into `∃v:int. i32[v]`. However, this category also contains generic variables - /// of kind [type] or [*special*]. We cannot distinguish these syntactially so we resolve them - /// later in the analysis. + /// of kind [type]. We cannot distinguish these syntactially so we resolve them later in the + /// analysis. /// /// [existential]: crate::rty::TyKind::Exists /// [type]: GenericParamKind::Type - /// [*special*]: GenericParamKind::SplTy BaseTy(BaseTy<'fhir>), Indexed(BaseTy<'fhir>, RefineArg<'fhir>), Exists(&'fhir [RefineParam<'fhir>], &'fhir Ty<'fhir>), @@ -1004,7 +1002,7 @@ impl<'fhir> Generics<'fhir> { pub fn with_refined_by(self, genv: GlobalEnv<'fhir, '_>, refined_by: &RefinedBy) -> Self { let params = genv.alloc_slice_fill_iter(self.params.iter().map(|param| { let kind = if refined_by.is_base_generic(param.def_id.to_def_id()) { - GenericParamKind::SplTy + GenericParamKind::Base } else { param.kind }; diff --git a/crates/flux-middle/src/global_env.rs b/crates/flux-middle/src/global_env.rs index ab08971880..db30b19dfe 100644 --- a/crates/flux-middle/src/global_env.rs +++ b/crates/flux-middle/src/global_env.rs @@ -213,12 +213,10 @@ impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> { let Some(poly_trait_ref) = self.tcx().impl_trait_ref(impl_id) else { return Ok(None) }; - let impl_generics = self.generics_of(impl_id)?; - let trait_ref = poly_trait_ref.skip_binder(); - let args = lowering::lower_generic_args(self.tcx(), trait_ref.args) + let trait_ref = lowering::lower_trait_ref(self.tcx(), poly_trait_ref.skip_binder()) .map_err(|err| QueryErr::unsupported(self.tcx(), impl_id, err.into_err()))?; - let args = self.refine_default_generic_args(&impl_generics, &args)?; - let trait_ref = rty::TraitRef { def_id: trait_ref.def_id, args }; + let impl_generics = self.generics_of(impl_id)?; + let trait_ref = Refiner::default(self, &impl_generics).refine_trait_ref(&trait_ref)?; Ok(Some(rty::EarlyBinder(trait_ref))) } @@ -268,7 +266,7 @@ impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> { self.inner.queries.item_bounds(self, def_id) } - pub fn type_of(self, def_id: DefId) -> QueryResult> { + pub fn type_of(self, def_id: DefId) -> QueryResult> { self.inner.queries.type_of(self, def_id) } @@ -322,19 +320,6 @@ impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> { generics.param_def_id_to_index[&def_id.to_def_id()] } - pub fn refine_default_generic_args( - self, - generics: &rty::Generics, - args: &ty::GenericArgs, - ) -> QueryResult { - let refiner = Refiner::default(self, generics); - let mut res = vec![]; - for arg in args { - res.push(refiner.refine_generic_arg_raw(arg)?); - } - Ok(res.into()) - } - pub fn refine_default( self, generics: &rty::Generics, diff --git a/crates/flux-middle/src/pretty.rs b/crates/flux-middle/src/pretty.rs index f4fb892c2a..f43e773f71 100644 --- a/crates/flux-middle/src/pretty.rs +++ b/crates/flux-middle/src/pretty.rs @@ -175,7 +175,7 @@ struct Env { impl Env { fn lookup(&self, debruijn: DebruijnIndex, index: u32) -> Option { self.layers - .get(self.layers.len() - debruijn.as_usize() - 1)? + .get(self.layers.len().checked_sub(debruijn.as_usize() + 1)?)? .get(&index) .copied() } @@ -328,7 +328,7 @@ impl PrettyCx<'_> { if let Some(name) = self.env.borrow().lookup(debruijn, var.index) { w!("{name:?}") } else { - w!("(⭡{debruijn:?}.{})", ^var.index) + w!("⭡{}/#{}", ^debruijn.as_usize(), ^var.index) } } BoundReftKind::Named(name) => w!("{name}"), diff --git a/crates/flux-middle/src/queries.rs b/crates/flux-middle/src/queries.rs index efc7a4fed8..306c51342f 100644 --- a/crates/flux-middle/src/queries.rs +++ b/crates/flux-middle/src/queries.rs @@ -40,6 +40,7 @@ pub type QueryResult = Result; #[derive(Debug, Clone)] pub enum QueryErr { Unsupported { def_id: DefId, def_span: Span, err: UnsupportedErr }, + InvalidGenericArg { def_id: DefId, def_span: Span }, Emitted(ErrorGuaranteed), } @@ -56,7 +57,7 @@ pub struct Providers { FluxLocalDefId, ) -> QueryResult>>, pub adt_def: fn(GlobalEnv, LocalDefId) -> QueryResult, - pub type_of: fn(GlobalEnv, LocalDefId) -> QueryResult>, + pub type_of: fn(GlobalEnv, LocalDefId) -> QueryResult>, pub variants_of: fn( GlobalEnv, LocalDefId, @@ -129,7 +130,7 @@ pub struct Queries<'genv, 'tcx> { assoc_refinement_def: Cache<(DefId, Symbol), QueryResult>>, sort_of_assoc_reft: Cache<(DefId, Symbol), Option>>, item_bounds: Cache>>>, - type_of: Cache>>, + type_of: Cache>>, variants_of: Cache>>>, fn_sig: Cache>>, lower_late_bound_vars: Cache>>, @@ -443,7 +444,7 @@ impl<'genv, 'tcx> Queries<'genv, 'tcx> { &self, genv: GlobalEnv, def_id: DefId, - ) -> QueryResult> { + ) -> QueryResult> { run_with_cache(&self.type_of, def_id, || { if let Some(local_id) = def_id.as_local() { (self.providers.type_of)(genv, local_id) @@ -576,6 +577,14 @@ impl<'a> IntoDiagnostic<'a> for QueryErr { builder.downgrade_to_delayed_bug(); builder } + QueryErr::InvalidGenericArg { def_span, .. } => { + let builder = handler.struct_span_err_with_code( + def_span, + fluent::middle_query_invalid_generic_arg, + flux_errors::diagnostic_id(), + ); + builder + } } } } diff --git a/crates/flux-middle/src/rty/canonicalize.rs b/crates/flux-middle/src/rty/canonicalize.rs new file mode 100644 index 0000000000..92b7c79f42 --- /dev/null +++ b/crates/flux-middle/src/rty/canonicalize.rs @@ -0,0 +1,195 @@ +//! A canonical type is a type where all [existentials] and [constraint predicates] are *hoisted* to +//! the top level. For example, the canonical version of `(∃a. i32[a], ∃b. { i32[b] | b > 0})` is +//! `∃a,b. { (i32[a], i32[b]) | b > 0}`. +//! +//! Canonicalization can be *shallow* or *deep*, by this we mean that some type constructors +//! introduce new "scopes" that limit the hoisting. For instance, we are not allowed (in general) to +//! hoist an existential type out of a generic argument, for example, in `Vec<∃v. i32[v]>` the +//! existential inside the `Vec` cannot be hoisted out. However, the type inside the generic argument +//! can be canonizalized locally inside the scope of the generic argument. Shallow canonicalization +//! stops when finding type constructors. In contrast, deep canonicalization also canonizalizes inside +//! type constructors. Note that some type constructors like shared references or boxes are transparent +//! to hoisting and do not introduce a new scope. +//! +//! It's also important to note that canonizalization doesn't imply any form of semantic equality +//! and it is just a best effort to facilitate syntactic manipulation. For example, the types +//! `∃a,b. (i32[a], i32[b])` and `∃a,b. (i32[b], i32[a])` are semantically equal but both are in +//! canonical form (in the current implementation). +//! +//! [existentials]: TyKind::Exists +//! [constraint predicates]: TyKind::Constr +use rustc_type_ir::{Mutability, INNERMOST}; + +use super::{ + box_args, + fold::{TypeFoldable, TypeFolder}, + BaseTy, Binder, BoundVariableKind, Expr, GenericArg, SubsetTy, SubsetTyCtor, Ty, TyKind, +}; +use crate::intern::List; + +#[derive(Default)] +pub struct ShallowHoister { + vars: Vec, + preds: Vec, +} + +impl ShallowHoister { + pub fn into_parts(self) -> (List, Vec) { + (List::from_vec(self.vars), self.preds) + } + + pub fn hoist(&mut self, ty: &Ty) -> Ty { + ty.fold_with(self) + } +} + +impl TypeFolder for ShallowHoister { + fn fold_ty(&mut self, ty: &Ty) -> Ty { + match ty.kind() { + TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()), + TyKind::Exists(ty) => { + ty.replace_bound_refts_with(|sort, mode, kind| { + let idx = self.vars.len(); + self.vars + .push(BoundVariableKind::Refine(sort.clone(), mode, kind)); + Expr::late_bvar(INNERMOST, idx as u32, kind) + }) + .fold_with(self) + } + TyKind::Constr(pred, ty) => { + self.preds.push(pred.clone()); + ty.fold_with(self) + } + _ => ty.clone(), + } + } + + fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy { + match bty { + BaseTy::Adt(adt_def, args) if adt_def.is_box() => { + let (boxed, alloc) = box_args(args); + let args = List::from_arr([ + GenericArg::Ty(boxed.fold_with(self)), + GenericArg::Ty(alloc.clone()), + ]); + BaseTy::Adt(adt_def.clone(), args) + } + BaseTy::Ref(re, ty, Mutability::Not) => { + BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not) + } + BaseTy::Tuple(tys) => BaseTy::Tuple(tys.fold_with(self)), + _ => bty.clone(), + } + } +} + +impl Ty { + pub fn shallow_canonicalize(&self) -> CanonicalTy { + let mut hoister = ShallowHoister::default(); + let ty = hoister.hoist(self); + let (vars, preds) = hoister.into_parts(); + let pred = Expr::and(preds); + let constr_ty = CanonicalConstrTy { ty, pred }; + if vars.is_empty() { + CanonicalTy::Constr(constr_ty) + } else { + CanonicalTy::Exists(Binder::new(constr_ty, vars)) + } + } +} + +pub struct CanonicalConstrTy { + /// Guranteed to not have any (shallow) [existential] or [constraint] types + /// + /// [existential]: TyKind::Exists + /// [constraint]: TyKind::Constr + ty: Ty, + pred: Expr, +} + +impl CanonicalConstrTy { + pub fn ty(&self) -> Ty { + self.ty.clone() + } + + pub fn pred(&self) -> Expr { + self.pred.clone() + } +} + +/// A (shallowly) canonicalized type. This can be either of the form `{T | p}` or `∃v0,…,vn. {T | p}`, +/// where `T` doesnt have any (shallow) [existential] or [constraint] types. +/// +/// When canonizalizing a type without a [constraint] type, `p` will be [`Expr::tt()`]. +/// +/// [existential]: TyKind::Exists +/// [constraint]: TyKind::Constr +pub enum CanonicalTy { + /// A type of the form `{T | p}` + Constr(CanonicalConstrTy), + /// A type of the form `∃v0,…,vn. {T | p}` + Exists(Binder), +} + +impl CanonicalTy { + pub fn to_subset_ty_ctor(&self) -> Option { + match self { + CanonicalTy::Constr(constr) => { + if let TyKind::Indexed(bty, idx) = constr.ty.kind() { + // given {b[e] | p} return λv. {b[v] | p ∧ v == e} + let sort = bty.sort(); + let constr = SubsetTy::new( + bty.clone(), + Expr::nu(), + Expr::and([constr.pred.clone(), Expr::eq(Expr::nu(), idx)]), + ); + Some(Binder::with_sort(constr, sort)) + } else { + None + } + } + CanonicalTy::Exists(poly_constr) => { + let constr = poly_constr.as_ref().skip_binder(); + if let TyKind::Indexed(bty, idx) = constr.ty.kind() + && idx.is_nu() + { + let ctor = poly_constr + .as_ref() + .map(|constr| SubsetTy::new(bty.clone(), Expr::nu(), &constr.pred)); + Some(ctor) + } else { + None + } + } + } + } +} + +mod pretty { + use super::*; + use crate::pretty::*; + + impl Pretty for CanonicalConstrTy { + fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + define_scoped!(cx, f); + w!("{{ {:?} | {:?} }}", &self.ty, &self.pred) + } + } + + impl Pretty for CanonicalTy { + fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + define_scoped!(cx, f); + match self { + CanonicalTy::Constr(constr) => w!("{:?}", constr), + CanonicalTy::Exists(poly_constr) => { + cx.with_bound_vars(poly_constr.vars(), || { + cx.fmt_bound_vars("∃", poly_constr.vars(), ". ", f)?; + w!("{:?}", poly_constr.as_ref().skip_binder()) + }) + } + } + } + } + + impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy); +} diff --git a/crates/flux-middle/src/rty/evars.rs b/crates/flux-middle/src/rty/evars.rs index 39ddb4175c..aa0b1973ff 100644 --- a/crates/flux-middle/src/rty/evars.rs +++ b/crates/flux-middle/src/rty/evars.rs @@ -149,7 +149,7 @@ mod pretty { impl Pretty for EVar { fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); - w!("?e{}#{}", ^self.id.as_u32(), ^self.cx.0) + w!("?{}e#{}", ^self.id.as_u32(), ^self.cx.0) } } diff --git a/crates/flux-middle/src/rty/expr.rs b/crates/flux-middle/src/rty/expr.rs index f39d1814fd..65dcec5ec9 100644 --- a/crates/flux-middle/src/rty/expr.rs +++ b/crates/flux-middle/src/rty/expr.rs @@ -6,13 +6,13 @@ use itertools::Itertools; use rustc_hir::def_id::DefId; use rustc_index::newtype_index; use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable}; -use rustc_middle::mir::Local; +use rustc_middle::{mir::Local, ty::TyCtxt}; use rustc_span::{BytePos, Span, Symbol, SyntaxContext}; use rustc_target::abi::FieldIdx; use rustc_type_ir::{DebruijnIndex, INNERMOST}; use super::{ - evars::EVar, AliasReft, BaseTy, Binder, BoundReftKind, BoundVariableKind, FuncSort, IntTy, + evars::EVar, BaseTy, Binder, BoundReftKind, BoundVariableKind, FuncSort, GenericArgs, IntTy, Sort, UintTy, }; use crate::{ @@ -57,6 +57,24 @@ impl Lambda { } } +#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)] +pub struct AliasReft { + pub trait_id: DefId, + pub name: Symbol, + pub args: GenericArgs, +} + +impl AliasReft { + pub fn to_rustc_trait_ref<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::TraitRef<'tcx> { + let trait_def_id = self.trait_id; + let args = self + .args + .to_rustc(tcx) + .truncate_to(tcx, tcx.generics_of(trait_def_id)); + rustc_middle::ty::TraitRef::new(tcx, trait_def_id, args) + } +} + pub type Expr = Interned; #[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)] @@ -358,6 +376,16 @@ impl Expr { Expr::late_bvar(INNERMOST, 0, BoundReftKind::Annon) } + pub fn is_nu(&self) -> bool { + if let ExprKind::Var(Var::LateBound(INNERMOST, var)) = self.kind() + && var.index == 0 + { + true + } else { + false + } + } + #[track_caller] pub fn expect_adt(&self) -> (DefId, List) { if let ExprKind::Aggregate(AggregateKind::Adt(def_id), flds) = self.kind() { @@ -548,8 +576,7 @@ impl Expr { } /// Simple syntactic check to see if the expression is a trivially true predicate. This is used - /// mostly for filtering predicates when pretty printing but also to simplify the constraint - /// before encoding it into fixpoint. + /// mostly for filtering predicates when pretty printing but also to simplify types in general. pub fn is_trivially_true(&self) -> bool { self.is_true() || matches!(self.kind(), ExprKind::BinaryOp(BinOp::Eq | BinOp::Iff | BinOp::Imp, e1, e2) if e1 == e2) @@ -797,6 +824,12 @@ impl From for Expr { } } +impl From for Expr { + fn from(var: Var) -> Self { + Expr::var(var, None) + } +} + impl From for Path { fn from(loc: Loc) -> Self { Path::new(loc, vec![]) @@ -956,7 +989,7 @@ mod pretty { impl Pretty for AliasReft { fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); - w!("<{:?} as {:?}", &self.args[0], self.trait_id)?; + w!("<({:?}) as {:?}", &self.args[0], self.trait_id)?; let args = &self.args[1..]; if !args.is_empty() { w!("<{:?}>", join!(", ", args))?; diff --git a/crates/flux-middle/src/rty/fold.rs b/crates/flux-middle/src/rty/fold.rs index 0a1dbde83c..c031b4e4c5 100644 --- a/crates/flux-middle/src/rty/fold.rs +++ b/crates/flux-middle/src/rty/fold.rs @@ -17,7 +17,7 @@ use super::{ AliasReft, AliasTy, BaseTy, BinOp, Binder, BoundVariableKind, Clause, ClauseKind, Constraint, CoroutineObligPredicate, Expr, ExprKind, FnOutput, FnSig, FnTraitPredicate, FuncSort, GenericArg, Invariant, KVar, Lambda, Name, OpaqueArgsMap, Opaqueness, OutlivesPredicate, - PolyFuncSort, ProjectionPredicate, PtrKind, Qualifier, ReLateBound, Region, Sort, + PolyFuncSort, ProjectionPredicate, PtrKind, Qualifier, ReLateBound, Region, Sort, SubsetTy, TraitPredicate, TraitRef, Ty, TyKind, }; use crate::{ @@ -77,6 +77,10 @@ pub trait FallibleTypeFolder: Sized { bty.try_super_fold_with(self) } + fn try_fold_subset_ty(&mut self, constr: &SubsetTy) -> Result { + constr.try_super_fold_with(self) + } + fn try_fold_region(&mut self, re: &Region) -> Result { Ok(*re) } @@ -103,6 +107,10 @@ pub trait TypeFolder: FallibleTypeFolder { bty.super_fold_with(self) } + fn fold_subset_ty(&mut self, constr: &SubsetTy) -> SubsetTy { + constr.super_fold_with(self) + } + fn fold_region(&mut self, re: &Region) -> Region { *re } @@ -137,6 +145,10 @@ where Ok(self.fold_bty(bty)) } + fn try_fold_subset_ty(&mut self, ty: &SubsetTy) -> Result { + Ok(self.fold_subset_ty(ty)) + } + fn try_fold_region(&mut self, re: &Region) -> Result { Ok(self.fold_region(re)) } @@ -967,11 +979,35 @@ impl TypeFoldable for AliasTy { } } +impl TypeVisitable for SubsetTy { + fn visit_with(&self, visitor: &mut V) -> ControlFlow { + self.bty.visit_with(visitor)?; + self.idx.visit_with(visitor)?; + self.pred.visit_with(visitor) + } +} + +impl TypeFoldable for SubsetTy { + fn try_fold_with(&self, folder: &mut F) -> Result { + folder.try_fold_subset_ty(self) + } +} + +impl TypeSuperFoldable for SubsetTy { + fn try_super_fold_with(&self, folder: &mut F) -> Result { + Ok(SubsetTy { + bty: self.bty.try_fold_with(folder)?, + idx: self.idx.try_fold_with(folder)?, + pred: self.pred.try_fold_with(folder)?, + }) + } +} + impl TypeVisitable for GenericArg { fn visit_with(&self, visitor: &mut V) -> ControlFlow { match self { GenericArg::Ty(ty) => ty.visit_with(visitor), - GenericArg::BaseTy(ty) => ty.visit_with(visitor), + GenericArg::Base(ty) => ty.visit_with(visitor), GenericArg::Lifetime(_) => ControlFlow::Continue(()), GenericArg::Const(_) => ControlFlow::Continue(()), } @@ -982,7 +1018,7 @@ impl TypeFoldable for GenericArg { fn try_fold_with(&self, folder: &mut F) -> Result { let arg = match self { GenericArg::Ty(ty) => GenericArg::Ty(ty.try_fold_with(folder)?), - GenericArg::BaseTy(ty) => GenericArg::BaseTy(ty.try_fold_with(folder)?), + GenericArg::Base(sty) => GenericArg::Base(sty.try_fold_with(folder)?), GenericArg::Lifetime(re) => GenericArg::Lifetime(re.try_fold_with(folder)?), GenericArg::Const(c) => GenericArg::Const(c.clone()), }; @@ -1076,9 +1112,8 @@ impl TypeVisitable for Var { impl TypeFoldable for AliasReft { fn try_fold_with(&self, folder: &mut F) -> Result { let trait_id = self.trait_id; - let generic_args = self.args.try_fold_with(folder)?; - let alias_pred = AliasReft { trait_id, name: self.name, args: generic_args }; - Ok(alias_pred) + let args = self.args.try_fold_with(folder)?; + Ok(AliasReft { trait_id, name: self.name, args }) } } diff --git a/crates/flux-middle/src/rty/mod.rs b/crates/flux-middle/src/rty/mod.rs index 1aeda1c7fa..688755cd62 100644 --- a/crates/flux-middle/src/rty/mod.rs +++ b/crates/flux-middle/src/rty/mod.rs @@ -4,6 +4,7 @@ //! //! * Types in this module use debruijn indices to represent local binders. //! * Data structures are interned so they can be cheaply cloned. +pub mod canonicalize; pub mod evars; mod expr; pub mod fold; @@ -16,8 +17,8 @@ use std::{borrow::Cow, fmt, hash::Hash, iter, slice, sync::LazyLock}; pub use evars::{EVar, EVarGen}; pub use expr::{ - AggregateKind, BinOp, BoundReft, Constant, ESpan, Expr, ExprKind, FieldProj, HoleKind, KVar, - KVid, Lambda, Loc, Name, Path, UnOp, Var, + AggregateKind, AliasReft, BinOp, BoundReft, Constant, ESpan, Expr, ExprKind, FieldProj, + HoleKind, KVar, KVid, Lambda, Loc, Name, Path, UnOp, Var, }; use flux_common::bug; use itertools::Itertools; @@ -151,8 +152,7 @@ pub struct GenericParamDef { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum GenericParamDefKind { Type { has_default: bool }, - SplTy, - BaseTy, + Base, Lifetime, Const { has_default: bool }, } @@ -553,7 +553,26 @@ pub struct ClosureOblig { pub oblig_sig: PolyFnSig, } -pub type PolyTy = Binder; +pub type TyCtor = Binder; + +impl TyCtor { + pub fn to_ty(&self) -> Ty { + match &self.vars[..] { + [] => return self.value.shift_out_escaping(1), + [BoundVariableKind::Refine(sort, ..)] => { + if sort.is_unit() { + return self.replace_bound_reft(&Expr::unit()); + } + if let Some(def_id) = sort.is_unit_adt() { + return self.replace_bound_reft(&Expr::unit_adt(def_id)); + } + } + _ => {} + } + Ty::exists(self.clone()) + } +} + pub type Ty = Interned; impl Ty { @@ -710,6 +729,57 @@ impl Ty { | TyKind::Blocked(_) => todo!(), } } + + /// Whether the type is an `int` or a `uint` + pub fn is_integral(&self) -> bool { + self.as_bty_skipping_existentials() + .map(BaseTy::is_integral) + .unwrap_or_default() + } + + /// Whether the type is a `bool` + pub fn is_bool(&self) -> bool { + self.as_bty_skipping_existentials() + .map(BaseTy::is_bool) + .unwrap_or_default() + } + + pub fn is_uninit(&self) -> bool { + matches!(self.kind(), TyKind::Uninit) + } + + pub fn is_box(&self) -> bool { + self.as_bty_skipping_existentials() + .map(BaseTy::is_box) + .unwrap_or_default() + } + + pub fn is_struct(&self) -> bool { + self.as_bty_skipping_existentials() + .map(BaseTy::is_struct) + .unwrap_or_default() + } + + pub fn is_array(&self) -> bool { + self.as_bty_skipping_existentials() + .map(BaseTy::is_array) + .unwrap_or_default() + } + + pub fn is_slice(&self) -> bool { + self.as_bty_skipping_existentials() + .map(BaseTy::is_slice) + .unwrap_or_default() + } + + pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> { + match self.kind() { + TyKind::Indexed(bty, _) => Some(bty), + TyKind::Exists(ty) => Some(ty.as_ref().skip_binder().as_bty_skipping_existentials()?), + TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(), + _ => None, + } + } } #[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)] @@ -717,24 +787,6 @@ pub struct TyS { kind: TyKind, } -#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)] -pub struct AliasReft { - pub trait_id: DefId, - pub name: Symbol, - pub args: GenericArgs, -} - -impl AliasReft { - pub fn to_rustc_trait_ref<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::TraitRef<'tcx> { - let trait_def_id = self.trait_id; - let args = self - .args - .to_rustc(tcx) - .truncate_to(tcx, tcx.generics_of(trait_def_id)); - rustc_middle::ty::TraitRef::new(tcx, trait_def_id, args) - } -} - #[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)] pub enum TyKind { Indexed(BaseTy, Expr), @@ -777,7 +829,7 @@ pub enum BaseTy { Tuple(List), Array(Ty, Const), Never, - Closure(DefId, List), + Closure(DefId, /* upvar_tys */ List), Coroutine(DefId, /*resume_ty: */ Ty, /* upvar_tys: */ List), Param(ParamTy), } @@ -810,10 +862,106 @@ pub type RefineArgs = List; pub type OpaqueArgsMap = FxHashMap; +/// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias +/// to [`Binder`], but we expect the binder to have a single bound variable of the sort of +/// the underlying [`BaseTy`]. +/// +/// [kind base]: GenericParamDefKind::Base +pub type SubsetTyCtor = Binder; + +impl SubsetTyCtor { + pub fn as_bty_skipping_binder(&self) -> &BaseTy { + &self.as_ref().skip_binder().bty + } + + pub fn to_ty(&self) -> Ty { + let sort = self.sort(); + if sort.is_unit() { + self.replace_bound_reft(&Expr::unit()).to_ty() + } else if let Some(def_id) = sort.is_unit_adt() { + self.replace_bound_reft(&Expr::unit_adt(def_id)).to_ty() + } else { + Ty::exists(self.as_ref().map(SubsetTy::to_ty)) + } + } +} + +/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a +/// [`BaseTy`], `e` a refinement index, and `p` a predicate. These are mainly found under a [`Binder`] +/// with a single variable of the base type's sort. This can be interpreted as a type constructor or +/// an existial type. For example, under a binder with a variable `v` of sort `int`, we can interpret +/// `{i32[v] | v > 0}` as a lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to +/// ints, or as an existential type `∃v:int. {i32[v] | v > 0}`. This second interpretation is the +/// reason we call this a subset type, i.e., the type `∃v. {b[v] | p}` corresponds to the subset of +/// values of (base) type `b` whose index satisfies `p`. In other words, these are the types supported +/// by liquid haskell (with the difference that we are explicit about separating refinements from +/// program values via an index). +/// +/// The main purpose for a [`SubsetTy`] is to be used as generic arguments of [kind base] when +/// interpreted as a type contructor. The key property of a [`SubsetTy`] is that it can be eagerly +/// canonicalized via [*strengthening*] during substitution. For example, suppose we have a function: +/// ```text +/// fn foo(x: T[@a], y: { T[@b] | b == a }) { } +/// ``` +/// If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the lambda, +/// we get: +/// ```text +/// fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { } +/// ``` +/// By the strengthening rule we can canonicalize this to +/// ```text +/// fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { } +/// ``` +/// As a result, we can guarantee a simple canonical form that makes it easier to manipulate types +/// syntactically. +/// +/// [kind base]: GenericParamDefKind::Base +/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf +#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)] +pub struct SubsetTy { + /// **NOTE:** This [`BaseTy`] is mainly going to be under a [`Binder`]. It is not yet clear whether + /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type + /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to + /// facilitate syntatic manipulation we may restrict this. + pub bty: BaseTy, + /// This can be an arbitrary expression which makes the syntatic manipulation easier, but since + /// this is mostly going to be under a binder we expect it to be [`Expr::nu()`]. + pub idx: Expr, + pub pred: Expr, +} + +impl SubsetTy { + pub fn new(bty: BaseTy, idx: impl Into, pred: impl Into) -> Self { + Self { bty, idx: idx.into(), pred: pred.into() } + } + + pub fn trivial(bty: BaseTy, idx: impl Into) -> Self { + Self::new(bty, idx, Expr::tt()) + } + + pub fn strengthen(&self, pred: impl Into) -> Self { + let this = self.clone(); + Self { bty: this.bty, idx: this.idx, pred: Expr::and([this.pred, pred.into()]) } + } + + fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> { + self.bty.to_rustc(tcx) + } + + fn to_ty(&self) -> Ty { + let bty = self.bty.clone(); + if self.pred.is_trivially_true() { + Ty::indexed(bty, &self.idx) + } else { + Ty::constr(&self.pred, Ty::indexed(bty, &self.idx)) + } + } +} + #[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)] pub enum GenericArg { Ty(Ty), - BaseTy(Binder), + Base(SubsetTyCtor), Lifetime(Region), Const(Const), } @@ -830,7 +978,7 @@ impl GenericArg { pub fn peel_out_sort(&self) -> Option { match self { GenericArg::Ty(ty) => ty.as_bty_skipping_existentials().map(BaseTy::sort), - GenericArg::BaseTy(abs) => Some(abs.vars()[0].expect_sort().clone()), + GenericArg::Base(ctor) => Some(ctor.sort()), GenericArg::Lifetime(_) | GenericArg::Const(_) => None, } } @@ -838,14 +986,14 @@ impl GenericArg { pub fn is_valid_base_arg(&self) -> bool { match self { GenericArg::Ty(ty) => ty.kind().is_valid_base_ty(), - GenericArg::BaseTy(bty) => bty.as_ref().skip_binder().kind().is_valid_base_ty(), + GenericArg::Base(_) => true, _ => false, } } fn from_param_def(genv: GlobalEnv, param: &GenericParamDef) -> QueryResult { match param.kind { - GenericParamDefKind::Type { .. } | GenericParamDefKind::SplTy => { + GenericParamDefKind::Type { .. } => { let param_ty = ParamTy { index: param.index, name: param.name }; Ok(GenericArg::Ty(Ty::param(param_ty))) } @@ -860,7 +1008,7 @@ impl GenericArg { let ty = genv.lower_type_of(param.def_id)?.skip_binder(); Ok(GenericArg::Const(Const { kind, ty })) } - GenericParamDefKind::BaseTy => { + GenericParamDefKind::Base => { bug!("") } } @@ -870,8 +1018,8 @@ impl GenericArg { use rustc_middle::ty; match self { GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)), - GenericArg::BaseTy(bty) => { - ty::GenericArg::from(bty.as_ref().skip_binder().to_rustc(tcx)) + GenericArg::Base(ctor) => { + ty::GenericArg::from(ctor.as_ref().skip_binder().to_rustc(tcx)) } GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)), GenericArg::Const(_) => todo!(), @@ -1174,6 +1322,14 @@ impl Binder { pub fn try_map(self, f: impl FnOnce(T) -> Result) -> Result, E> { Ok(Binder { vars: self.vars, value: f(self.value)? }) } + + #[track_caller] + pub fn sort(&self) -> Sort { + match &self.vars[..] { + [BoundVariableKind::Refine(sort, ..)] => sort.clone(), + _ => bug!("expected single-sorted binder"), + } + } } impl List { @@ -1275,16 +1431,6 @@ where } } -impl Binder { - pub fn into_ty(self) -> Ty { - if self.vars.is_empty() { - self.value - } else { - Ty::exists(self) - } - } -} - impl EarlyBinder { pub fn instantiate(self, args: &[GenericArg], refine_args: &[Expr]) -> T { self.0.fold_with(&mut subst::GenericsSubstFolder::new( @@ -1561,62 +1707,12 @@ impl TyS { } } + #[track_caller] pub(crate) fn expect_tuple(&self) -> &[Ty] { if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() { tys } else { - bug!("expected adt") - } - } - - /// Whether the type is an `int` or a `uint` - pub fn is_integral(&self) -> bool { - self.as_bty_skipping_existentials() - .map(BaseTy::is_integral) - .unwrap_or_default() - } - - /// Whether the type is a `bool` - pub fn is_bool(&self) -> bool { - self.as_bty_skipping_existentials() - .map(BaseTy::is_bool) - .unwrap_or_default() - } - - pub fn is_uninit(&self) -> bool { - matches!(self.kind(), TyKind::Uninit) - } - - pub fn is_box(&self) -> bool { - self.as_bty_skipping_existentials() - .map(BaseTy::is_box) - .unwrap_or_default() - } - - pub fn is_struct(&self) -> bool { - self.as_bty_skipping_existentials() - .map(BaseTy::is_struct) - .unwrap_or_default() - } - - pub fn is_array(&self) -> bool { - self.as_bty_skipping_existentials() - .map(BaseTy::is_array) - .unwrap_or_default() - } - - pub fn is_slice(&self) -> bool { - self.as_bty_skipping_existentials() - .map(BaseTy::is_slice) - .unwrap_or_default() - } - - pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> { - match self.kind() { - TyKind::Indexed(bty, _) => Some(bty), - TyKind::Exists(ty) => Some(ty.as_ref().skip_binder().as_bty_skipping_existentials()?), - TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(), - _ => None, + bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind()) } } } @@ -1665,6 +1761,10 @@ impl BaseTy { matches!(self, BaseTy::Slice(..)) } + fn is_adt(&self) -> bool { + matches!(self, BaseTy::Adt(..)) + } + pub fn is_box(&self) -> bool { matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box()) } @@ -2028,7 +2128,7 @@ mod pretty { default fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); cx.with_bound_vars(&self.vars, || { - cx.fmt_bound_vars("for<", &self.vars, ">", f)?; + cx.fmt_bound_vars("for<", &self.vars, "> ", f)?; w!("{:?}", &self.value) }) } @@ -2046,7 +2146,7 @@ mod pretty { let vars = &self.vars; cx.with_bound_vars(vars, || { if !vars.is_empty() { - cx.fmt_bound_vars("for<", vars, ">", f)?; + cx.fmt_bound_vars("for<", vars, "> ", f)?; } w!("{:?}", &self.value) }) @@ -2162,16 +2262,25 @@ mod pretty { define_scoped!(cx, f); let vars = &self.vars; cx.with_bound_vars(vars, || { - cx.fmt_bound_vars("exists<", vars, ">", f)?; - w!("{:?}", &self.value.ret)?; - if !self.value.ensures.is_empty() { - w!("; [{:?}]", join!(", ", &self.value.ensures))?; + if !vars.is_empty() { + cx.fmt_bound_vars("exists<", vars, "> ", f)?; } - Ok(()) + w!("{:?}", &self.value) }) } } + impl Pretty for FnOutput { + fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { + define_scoped!(cx, f); + w!("{:?}", &self.ret)?; + if !self.ensures.is_empty() { + w!("; [{:?}]", join!(", ", &self.ensures))?; + } + Ok(()) + } + } + impl Pretty for Constraint { fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); @@ -2182,6 +2291,17 @@ mod pretty { } } + impl Pretty for SubsetTy { + fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { + define_scoped!(cx, f); + if self.pred.is_trivially_true() { + w!("{:?}[{:?}]", &self.bty, &self.idx) + } else { + w!("{{ {:?}[{:?}] | {:?} }}", &self.bty, &self.idx, &self.pred) + } + } + } + impl Pretty for TyS { fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); @@ -2192,7 +2312,9 @@ mod pretty { return Ok(()); } if idx.is_unit() { - w!("[]")?; + if bty.is_adt() { + w!("[]")?; + } } else { w!("[{:?}]", idx)?; } @@ -2365,11 +2487,11 @@ mod pretty { fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); match self { - GenericArg::Ty(arg) => w!("{:?}", arg), - GenericArg::BaseTy(arg) => { - cx.with_bound_vars(arg.vars(), || { - cx.fmt_bound_vars("λ", arg.vars(), ". ", f)?; - w!("{:?}", arg.as_ref().skip_binder()) + GenericArg::Ty(ty) => w!("{:?}", ty), + GenericArg::Base(ctor) => { + cx.with_bound_vars(ctor.vars(), || { + cx.fmt_bound_vars("λ", ctor.vars(), ". ", f)?; + w!("{:?}", &ctor.value) }) } GenericArg::Lifetime(re) => w!("{:?}", re), @@ -2410,5 +2532,6 @@ mod pretty { PtrKind, FuncSort, SortCtor, + SubsetTy, ); } diff --git a/crates/flux-middle/src/rty/projections.rs b/crates/flux-middle/src/rty/projections.rs index ac22147ab6..df61e60a14 100644 --- a/crates/flux-middle/src/rty/projections.rs +++ b/crates/flux-middle/src/rty/projections.rs @@ -6,14 +6,14 @@ use rustc_hir::def_id::DefId; use rustc_infer::{infer::InferCtxt, traits::Obligation}; use rustc_middle::{ traits::{ImplSource, ObligationCause}, - ty::{EarlyBoundRegion, ParamTy, TyCtxt}, + ty::TyCtxt, }; use rustc_trait_selection::traits::SelectionContext; use super::{ fold::{FallibleTypeFolder, TypeSuperFoldable}, - AliasKind, AliasReft, AliasTy, BaseTy, Clause, ClauseKind, Expr, ExprKind, GenericArg, - GenericArgs, ProjectionPredicate, RefineArgs, Region, Ty, TyKind, + AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Expr, ExprKind, GenericArg, + ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind, }; use crate::{ global_env::GlobalEnv, @@ -49,14 +49,16 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { ) -> QueryResult { if let Some(impl_def_id) = self.impl_id_of_alias_reft(obligation)? { let impl_trait_ref = self - .tcx() - .impl_trait_ref(impl_def_id) + .genv + .impl_trait_ref(impl_def_id)? .unwrap() .skip_binder(); let generics = self.tcx().generics_of(impl_def_id); let mut subst = TVarSubst::new(generics); - subst.infer_from_args(impl_trait_ref.args, &obligation.args); + for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) { + subst.generic_args(a, b); + } let args = subst.finish(self.tcx(), generics); let pred = self @@ -100,15 +102,17 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { // => {T -> {v. i32[v] | v > 0}, A -> Global} let impl_trait_ref = self - .tcx() - .impl_trait_ref(impl_def_id) + .genv + .impl_trait_ref(impl_def_id)? .unwrap() .skip_binder(); let generics = self.tcx().generics_of(impl_def_id); let mut subst = TVarSubst::new(generics); - subst.infer_from_args(impl_trait_ref.args, &obligation.args); + for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) { + subst.generic_args(a, b); + } let args = subst.finish(self.tcx(), generics); // 2. Get the associated type in the impl block and apply the substitution to it @@ -124,7 +128,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { .genv .type_of(assoc_type_id)? .instantiate(&args, &[]) - .into_ty()) + .to_ty()) } } } @@ -146,7 +150,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { &self, obligation: &AliasTy, candidates: &mut Vec, - ) -> QueryResult<()> { + ) -> QueryResult { if let GenericArg::Ty(ty) = &obligation.args[0] && let TyKind::Alias(AliasKind::Opaque, alias_ty) = ty.kind() { @@ -283,96 +287,73 @@ impl TVarSubst { .collect() } - fn insert_param_ty(&mut self, pty: ParamTy, ty: &Ty) { - let arg = GenericArg::Ty(ty.clone()); - if self.args[pty.index as usize].replace(arg).is_some() { - bug!("duplicate insert"); - } - } - - fn insert_early_bound_region(&mut self, ebr: EarlyBoundRegion, re: Region) { - let arg = GenericArg::Lifetime(re); - if self.args[ebr.index as usize].replace(arg).is_some() { - bug!("duplicate insert"); - } - } - - fn infer_from_args(&mut self, src: rustc_middle::ty::GenericArgsRef, dst: &GenericArgs) { - debug_assert_eq!(src.len(), dst.len()); - for (src, dst) in iter::zip(src, dst) { - self.infer_from_arg(src, dst); + fn generic_args(&mut self, a: &GenericArg, b: &GenericArg) { + match (a, b) { + (GenericArg::Ty(a), GenericArg::Ty(b)) => self.tys(a, b), + (GenericArg::Lifetime(a), GenericArg::Lifetime(b)) => self.regions(*a, *b), + (GenericArg::Base(a), GenericArg::Base(b)) => { + self.btys(a.as_bty_skipping_binder(), b.as_bty_skipping_binder()); + } + _ => {} } } - fn infer_from_arg(&mut self, src: rustc_middle::ty::GenericArg, dst: &GenericArg) { - match dst { - GenericArg::Ty(dst) => { - self.infer_from_ty(&src.as_type().unwrap(), dst); - } - GenericArg::Lifetime(dst) => self.infer_from_region(&src.as_region().unwrap(), dst), - GenericArg::BaseTy(bty) => { - self.infer_from_ty(&src.as_type().unwrap(), &bty.clone().skip_binder()); + fn tys(&mut self, a: &Ty, b: &Ty) { + if let TyKind::Param(param_ty) = a.kind() { + if !b.has_escaping_bvars() { + self.insert_generic_arg(param_ty.index, GenericArg::Ty(b.clone())); } - _ => (), + return; } + let Some(a_bty) = a.as_bty_skipping_existentials() else { return }; + let Some(b_bty) = b.as_bty_skipping_existentials() else { return }; + self.btys(a_bty, b_bty); } - fn infer_from_ty(&mut self, src: &rustc_middle::ty::Ty, dst: &Ty) { - use rustc_middle::ty; - match src.kind() { - ty::TyKind::Param(pty) => self.insert_param_ty(*pty, dst), - ty::TyKind::Adt(_, src_subst) => { - // NOTE: see https://github.com/flux-rs/flux/pull/478#issuecomment-1650983695 - if let Some(dst) = dst.as_bty_skipping_existentials() - && !dst.has_escaping_bvars() - && let BaseTy::Adt(_, dst_subst) = dst - { - debug_assert_eq!(src_subst.len(), dst_subst.len()); - for (src_arg, dst_arg) in iter::zip(*src_subst, dst_subst) { - self.infer_from_arg(src_arg, dst_arg); - } - } else { - bug!("unexpected type {dst:?}"); + fn btys(&mut self, a: &BaseTy, b: &BaseTy) { + match (a, b) { + (BaseTy::Param(param_ty), _) => { + if !b.has_escaping_bvars() { + let sort = b.sort(); + let ctor = Binder::with_sort(SubsetTy::trivial(b.clone(), Expr::nu()), sort); + self.insert_generic_arg(param_ty.index, GenericArg::Base(ctor)); } } - ty::TyKind::Array(src, _) => { - if let Some(BaseTy::Array(dst, _)) = dst.as_bty_skipping_existentials() { - self.infer_from_ty(src, dst); - } else { - bug!("unexpected type {dst:?}"); + (BaseTy::Adt(_, a_args), BaseTy::Adt(_, b_args)) => { + debug_assert_eq!(a_args.len(), b_args.len()); + for (a_arg, b_arg) in iter::zip(a_args, b_args) { + self.generic_args(a_arg, b_arg); } } - ty::TyKind::Slice(src) => { - if let Some(BaseTy::Slice(dst)) = dst.as_bty_skipping_existentials() { - self.infer_from_ty(src, dst); - } else { - bug!("unexpected type {dst:?}"); - } + (BaseTy::Array(a_ty, _), BaseTy::Array(b_ty, _)) => { + self.tys(a_ty, b_ty); } - ty::TyKind::Tuple(src_tys) => { - if let Some(BaseTy::Tuple(dst_tys)) = dst.as_bty_skipping_existentials() { - debug_assert_eq!(src_tys.len(), dst_tys.len()); - iter::zip(src_tys.iter(), dst_tys.iter()) - .for_each(|(src, dst)| self.infer_from_ty(&src, dst)); - } else { - bug!("unexpected type {dst:?}"); + (BaseTy::Tuple(a_tys), BaseTy::Tuple(b_tys)) => { + debug_assert_eq!(a_tys.len(), b_tys.len()); + for (a_ty, b_ty) in iter::zip(a_tys, b_tys) { + self.tys(a_ty, b_ty); } } - ty::TyKind::Ref(src_re, src_ty, _) => { - if let Some(BaseTy::Ref(dst_re, dst_ty, _)) = dst.as_bty_skipping_existentials() { - self.infer_from_region(src_re, dst_re); - self.infer_from_ty(src_ty, dst_ty); - } else { - bug!("unexpected type {dst:?}"); - } + (BaseTy::Ref(a_re, a_ty, _), BaseTy::Ref(b_re, b_ty, _)) => { + self.regions(*a_re, *b_re); + self.tys(a_ty, b_ty); + } + (BaseTy::Slice(a_ty), BaseTy::Slice(b_ty)) => { + self.tys(a_ty, b_ty); } _ => {} } } - fn infer_from_region(&mut self, src: &rustc_middle::ty::Region, dst: &Region) { - if let rustc_middle::ty::RegionKind::ReEarlyBound(ebr) = src.kind() { - self.insert_early_bound_region(ebr, *dst); + fn regions(&mut self, a: Region, b: Region) { + if let Region::ReEarlyBound(ebr) = a { + self.insert_generic_arg(ebr.index, GenericArg::Lifetime(b)); + } + } + + fn insert_generic_arg(&mut self, idx: u32, arg: GenericArg) { + if self.args[idx as usize].replace(arg).is_some() { + bug!("duplicate insert"); } } } diff --git a/crates/flux-middle/src/rty/refining.rs b/crates/flux-middle/src/rty/refining.rs index c584d383e0..1655b5f9af 100644 --- a/crates/flux-middle/src/rty/refining.rs +++ b/crates/flux-middle/src/rty/refining.rs @@ -7,7 +7,12 @@ use rustc_hir::def_id::DefId; use rustc_middle::ty::{ClosureKind, ParamTy}; use super::fold::TypeFoldable; -use crate::{global_env::GlobalEnv, intern::List, queries::QueryResult, rty, rustc}; +use crate::{ + global_env::GlobalEnv, + intern::List, + queries::{QueryErr, QueryResult}, + rty, rustc, +}; pub(crate) fn refine_generics(generics: &rustc::ty::Generics) -> QueryResult { let params = generics @@ -38,14 +43,14 @@ pub(crate) fn refine_generics(generics: &rustc::ty::Generics) -> QueryResult { genv: GlobalEnv<'genv, 'tcx>, generics: rty::Generics, - refine: fn(rty::BaseTy) -> rty::Binder, + refine: fn(rty::BaseTy) -> rty::Binder, } impl<'genv, 'tcx> Refiner<'genv, 'tcx> { pub fn new( genv: GlobalEnv<'genv, 'tcx>, generics: &rty::Generics, - refine: fn(rty::BaseTy) -> rty::Binder, + refine: fn(rty::BaseTy) -> rty::Binder, ) -> Self { Self { genv, generics: generics.clone(), refine } } @@ -60,8 +65,11 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { generics: generics.clone(), refine: |bty| { let sort = bty.sort(); - let indexed = rty::Ty::indexed(bty.shift_in_escaping(1), rty::Expr::nu()); - let constr = rty::Ty::constr(rty::Expr::hole(rty::HoleKind::Pred), indexed); + let constr = rty::SubsetTy::new( + bty.shift_in_escaping(1), + rty::Expr::nu(), + rty::Expr::hole(rty::HoleKind::Pred), + ); rty::Binder::with_sort(constr, sort) }, } @@ -113,12 +121,12 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { &rustc::ty::AliasKind::Projection, &proj_pred.projection_ty, )?, - term: self.as_default().refine_ty(&proj_pred.term)?, + term: self.refine_ty(&proj_pred.term)?, }; rty::ClauseKind::Projection(pred) } rustc::ty::ClauseKind::TypeOutlives(pred) => { - let pred = rty::OutlivesPredicate(self.as_default().refine_ty(&pred.0)?, pred.1); + let pred = rty::OutlivesPredicate(self.refine_ty(&pred.0)?, pred.1); rty::ClauseKind::TypeOutlives(pred) } }; @@ -153,14 +161,13 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { Ok(rty::ClauseKind::FnTrait(pred)) } - fn refine_trait_ref(&self, trait_ref: &rustc::ty::TraitRef) -> QueryResult { + pub(crate) fn refine_trait_ref( + &self, + trait_ref: &rustc::ty::TraitRef, + ) -> QueryResult { let trait_ref = rty::TraitRef { def_id: trait_ref.def_id, - args: trait_ref - .args - .iter() - .map(|arg| self.refine_generic_arg_raw(arg)) - .try_collect()?, + args: self.refine_generic_args(trait_ref.def_id, &trait_ref.args)?, }; Ok(trait_ref) } @@ -211,6 +218,20 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { }) } + fn refine_generic_args( + &self, + def_id: DefId, + args: &rustc::ty::GenericArgs, + ) -> QueryResult { + let generics = self.generics_of(def_id)?; + let mut result = vec![]; + for (idx, arg) in args.iter().enumerate() { + let param = generics.param_at(idx, self.genv)?; + result.push(self.refine_generic_arg(¶m, arg)?); + } + Ok(List::from_vec(result)) + } + pub fn refine_generic_arg( &self, param: &rty::GenericParamDef, @@ -220,11 +241,12 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { (rty::GenericParamDefKind::Type { .. }, rustc::ty::GenericArg::Ty(ty)) => { Ok(rty::GenericArg::Ty(self.refine_ty(ty)?)) } - (rty::GenericParamDefKind::SplTy, rustc::ty::GenericArg::Ty(ty)) => { - Ok(rty::GenericArg::Ty(self.refine_ty(ty)?)) - } - (rty::GenericParamDefKind::BaseTy, rustc::ty::GenericArg::Ty(ty)) => { - Ok(rty::GenericArg::BaseTy(self.refine_poly_ty(ty)?)) + (rty::GenericParamDefKind::Base, rustc::ty::GenericArg::Ty(ty)) => { + let TyOrBase::Base(contr) = self.refine_ty_inner(ty)? else { + let def_span = self.genv.tcx().def_span(param.def_id); + return Err(QueryErr::InvalidGenericArg { def_id: param.def_id, def_span }); + }; + Ok(rty::GenericArg::Base(contr)) } (rty::GenericParamDefKind::Lifetime, rustc::ty::GenericArg::Lifetime(re)) => { Ok(rty::GenericArg::Lifetime(*re)) @@ -236,26 +258,13 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { } } - pub(crate) fn refine_generic_arg_raw( - &self, - arg: &rustc::ty::GenericArg, - ) -> QueryResult { - match arg { - rustc::ty::GenericArg::Ty(ty) => Ok(rty::GenericArg::Ty(self.refine_ty(ty)?)), - rustc::ty::GenericArg::Lifetime(re) => Ok(rty::GenericArg::Lifetime(*re)), - rustc::ty::GenericArg::Const(c) => Ok(rty::GenericArg::Const(c.clone())), - } - } - - pub(crate) fn refine_alias_ty( + fn refine_alias_ty( &self, alias_kind: &rustc::ty::AliasKind, alias_ty: &rustc::ty::AliasTy, ) -> QueryResult { let def_id = alias_ty.def_id; - let args = self.iter_with_generics_of(def_id, &alias_ty.args, |param, arg| { - self.as_default().refine_generic_arg(param, arg) - })?; + let args = self.refine_generic_args(def_id, &alias_ty.args)?; let refine_args = self.refine_args_of(def_id, alias_kind)?; @@ -263,22 +272,15 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { Ok(res) } - pub(crate) fn refine_ty(&self, ty: &rustc::ty::Ty) -> QueryResult { - let poly_ty = self.refine_poly_ty(ty)?; - let ty = match &poly_ty.vars()[..] { - [] => poly_ty.skip_binder().shift_out_escaping(1), - [rty::BoundVariableKind::Refine(s, ..)] => { - if s.is_unit() { - poly_ty.replace_bound_reft(&rty::Expr::unit()) - } else if let Some(def_id) = s.is_unit_adt() { - poly_ty.replace_bound_reft(&rty::Expr::unit_adt(def_id)) - } else { - rty::Ty::exists(poly_ty) - } - } - _ => rty::Ty::exists(poly_ty), - }; - Ok(ty) + pub fn refine_ty(&self, ty: &rustc::ty::Ty) -> QueryResult { + Ok(self.refine_ty_inner(ty)?.into_ty()) + } + + pub fn refine_ty_ctor(&self, ty: &rustc::ty::Ty) -> QueryResult { + Ok(self + .refine_ty_inner(ty)? + .expect_simple() + .map(|constr| constr.to_ty())) } fn refine_alias_kind(kind: &rustc::ty::AliasKind) -> rty::AliasKind { @@ -288,7 +290,7 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { } } - pub fn refine_poly_ty(&self, ty: &rustc::ty::Ty) -> QueryResult { + fn refine_ty_inner(&self, ty: &rustc::ty::Ty) -> QueryResult { let bty = match ty.kind() { rustc::ty::TyKind::Closure(did, args) => { let args = args.as_closure(); @@ -325,25 +327,24 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { } rustc::ty::TyKind::Param(param_ty) => { match self.param(*param_ty)?.kind { - rty::GenericParamDefKind::Type { .. } | rty::GenericParamDefKind::SplTy => { - return Ok(rty::Binder::new(rty::Ty::param(*param_ty), List::empty())); + rty::GenericParamDefKind::Type { .. } => { + return Ok(TyOrBase::Ty(rty::Ty::param(*param_ty))); + } + rty::GenericParamDefKind::Base => rty::BaseTy::Param(*param_ty), + rty::GenericParamDefKind::Lifetime | rty::GenericParamDefKind::Const { .. } => { + bug!() } - rty::GenericParamDefKind::BaseTy => rty::BaseTy::Param(*param_ty), - rty::GenericParamDefKind::Lifetime => bug!(), - rty::GenericParamDefKind::Const { .. } => bug!(), } } rustc::ty::TyKind::Adt(adt_def, args) => { let adt_def = self.genv.adt_def(adt_def.did())?; - let args = self.iter_with_generics_of(adt_def.did(), args, |param, arg| { - self.refine_generic_arg(param, arg) - })?; + let args = self.refine_generic_args(adt_def.did(), args)?; rty::BaseTy::adt(adt_def, args) } rustc::ty::TyKind::Alias(alias_kind, alias_ty) => { let kind = Self::refine_alias_kind(alias_kind); - let alias_ty = self.refine_alias_ty(alias_kind, alias_ty)?; - return Ok(rty::Binder::new(rty::Ty::alias(kind, alias_ty), List::empty())); + let alias_ty = self.as_default().refine_alias_ty(alias_kind, alias_ty)?; + return Ok(TyOrBase::Ty(rty::Ty::alias(kind, alias_ty))); } rustc::ty::TyKind::Bool => rty::BaseTy::Bool, rustc::ty::TyKind::Int(int_ty) => rty::BaseTy::Int(*int_ty), @@ -356,7 +357,7 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { rty::BaseTy::RawPtr(self.as_default().refine_ty(ty)?, *mu) } }; - Ok((self.refine)(bty)) + Ok(TyOrBase::Base((self.refine)(bty))) } fn as_default(&self) -> Self { @@ -387,39 +388,38 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { } } - fn iter_with_generics_of( - &self, - def_id: DefId, - args: &[rustc::ty::GenericArg], - f: impl FnMut(&rty::GenericParamDef, &rustc::ty::GenericArg) -> QueryResult, - ) -> QueryResult { - let generics = self.generics_of(def_id)?; - self.iter_with_generic_params(&generics, args, f) + fn param(&self, param_ty: ParamTy) -> QueryResult { + self.generics.param_at(param_ty.index as usize, self.genv) } +} - fn iter_with_generic_params( - &self, - generics: &rty::Generics, - args: &[rustc::ty::GenericArg], - mut f: impl FnMut(&rty::GenericParamDef, &rustc::ty::GenericArg) -> QueryResult, - ) -> QueryResult { - args.iter() - .enumerate() - .map(|(idx, arg)| { - let param = generics.param_at(idx, self.genv)?; - f(¶m, arg) - }) - .try_collect() +enum TyOrBase { + Ty(rty::Ty), + Base(rty::SubsetTyCtor), +} + +impl TyOrBase { + fn into_ty(self) -> rty::Ty { + match self { + TyOrBase::Ty(ty) => ty, + TyOrBase::Base(ctor) => ctor.to_ty(), + } } - fn param(&self, param_ty: ParamTy) -> QueryResult { - self.generics.param_at(param_ty.index as usize, self.genv) + #[track_caller] + fn expect_simple(self) -> rty::Binder { + if let TyOrBase::Base(poly_constr) = self { + poly_constr + } else { + bug!("unexpected ty") + } } } -fn refine_default(bty: rty::BaseTy) -> rty::Binder { +fn refine_default(bty: rty::BaseTy) -> rty::Binder { let sort = bty.sort(); - rty::Binder::with_sort(rty::Ty::indexed(bty.shift_in_escaping(1), rty::Expr::nu()), sort) + let constr = rty::SubsetTy::trivial(bty.shift_in_escaping(1), rty::Expr::nu()); + rty::Binder::with_sort(constr, sort) } pub fn refine_bound_variables( diff --git a/crates/flux-middle/src/rty/subst.rs b/crates/flux-middle/src/rty/subst.rs index c7e6c280f7..45781a9d12 100644 --- a/crates/flux-middle/src/rty/subst.rs +++ b/crates/flux-middle/src/rty/subst.rs @@ -71,8 +71,8 @@ impl RegionSubst { debug_assert_eq!(args1.len(), args2.len()); for (arg1, arg2) in iter::zip(args1, args2) { match (arg1, arg2) { - (GenericArg::BaseTy(ty_con), ty::GenericArg::Ty(ty2)) => { - self.infer_from_ty(ty_con.as_ref().skip_binder(), ty2); + (GenericArg::Base(ctor1), ty::GenericArg::Ty(ty2)) => { + self.infer_from_bty(ctor1.as_bty_skipping_binder(), ty2); } (GenericArg::Ty(ty1), ty::GenericArg::Ty(ty2)) => { self.infer_from_ty(ty1, ty2); @@ -253,7 +253,7 @@ pub(crate) struct GenericsSubstFolder<'a, D> { trait GenericsSubstDelegate { fn sort_for_param(&mut self, param_ty: ParamTy) -> Sort; fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty; - fn bty_for_param(&mut self, param_ty: ParamTy, idx: &Expr) -> Ty; + fn ctor_for_param(&mut self, param_ty: ParamTy) -> SubsetTyCtor; fn region_for_param(&mut self, ebr: EarlyBoundRegion) -> Region; } @@ -273,8 +273,11 @@ impl GenericsSubstDelegate for IdentitySubstDelegate { Ty::param(param_ty) } - fn bty_for_param(&mut self, param_ty: ParamTy, idx: &Expr) -> Ty { - Ty::indexed(BaseTy::Param(param_ty), idx.clone()) + fn ctor_for_param(&mut self, param_ty: ParamTy) -> SubsetTyCtor { + Binder::with_sort( + SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()), + Sort::Param(param_ty), + ) } fn region_for_param(&mut self, ebr: EarlyBoundRegion) -> Region { @@ -301,9 +304,9 @@ impl GenericsSubstDelegate for GenericArgsDelegate<'_> { } } - fn bty_for_param(&mut self, param_ty: ParamTy, idx: &Expr) -> Ty { + fn ctor_for_param(&mut self, param_ty: ParamTy) -> SubsetTyCtor { match self.0.get(param_ty.index as usize) { - Some(GenericArg::BaseTy(arg)) => arg.replace_bound_reft(idx), + Some(GenericArg::Base(ctor)) => ctor.clone(), Some(arg) => { bug!("expected base type for generic parameter, found `{:?}`", arg) } @@ -350,7 +353,7 @@ where bug!("unexpected type param {param_ty:?}"); } - fn bty_for_param(&mut self, param_ty: ParamTy, _idx: &Expr) -> Ty { + fn ctor_for_param(&mut self, param_ty: ParamTy) -> SubsetTyCtor { bug!("unexpected base type param {param_ty:?}"); } @@ -386,12 +389,26 @@ impl TypeFolder for GenericsSubstFolder<'_, D> { TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty), TyKind::Indexed(BaseTy::Param(param_ty), idx) => { let idx = idx.fold_with(self); - self.delegate.bty_for_param(*param_ty, &idx) + self.delegate + .ctor_for_param(*param_ty) + .replace_bound_reft(&idx) + .to_ty() } _ => ty.super_fold_with(self), } } + fn fold_subset_ty(&mut self, constr: &SubsetTy) -> SubsetTy { + if let BaseTy::Param(param_ty) = &constr.bty { + self.delegate + .ctor_for_param(*param_ty) + .replace_bound_reft(&constr.idx) + .strengthen(&constr.pred) + } else { + constr.super_fold_with(self) + } + } + fn fold_region(&mut self, re: &Region) -> Region { if let ReEarlyBound(ebr) = *re { self.delegate.region_for_param(ebr) diff --git a/crates/flux-middle/src/rustc/lowering.rs b/crates/flux-middle/src/rustc/lowering.rs index 373d40db1f..611a14811d 100644 --- a/crates/flux-middle/src/rustc/lowering.rs +++ b/crates/flux-middle/src/rustc/lowering.rs @@ -865,10 +865,7 @@ fn lower_clause<'tcx>( let kind = match kind { rustc_ty::ClauseKind::Trait(trait_pred) => { ClauseKind::Trait(TraitPredicate { - trait_ref: TraitRef { - def_id: trait_pred.trait_ref.def_id, - args: lower_generic_args(tcx, trait_pred.trait_ref.args)?, - }, + trait_ref: lower_trait_ref(tcx, trait_pred.trait_ref)?, }) } rustc_ty::ClauseKind::Projection(proj_pred) => { @@ -894,6 +891,13 @@ fn lower_clause<'tcx>( Ok(Clause::new(kind)) } +pub(crate) fn lower_trait_ref<'tcx>( + tcx: TyCtxt<'tcx>, + trait_ref: rustc_ty::TraitRef<'tcx>, +) -> Result { + Ok(TraitRef { def_id: trait_ref.def_id, args: lower_generic_args(tcx, trait_ref.args)? }) +} + fn lower_type_outlives<'tcx>( tcx: TyCtxt<'tcx>, pred: rustc_ty::TypeOutlivesPredicate<'tcx>, diff --git a/crates/flux-middle/src/sort_of.rs b/crates/flux-middle/src/sort_of.rs index 03ef1cb6b0..280c098a38 100644 --- a/crates/flux-middle/src/sort_of.rs +++ b/crates/flux-middle/src/sort_of.rs @@ -57,9 +57,7 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { fn sort_of_generic_param(self, def_id: LocalDefId) -> Option { let param = self.get_generic_param(def_id); match ¶m.kind { - fhir::GenericParamKind::BaseTy | fhir::GenericParamKind::SplTy => { - Some(rty::Sort::Param(self.def_id_to_param_ty(def_id))) - } + fhir::GenericParamKind::Base => Some(rty::Sort::Param(self.def_id_to_param_ty(def_id))), fhir::GenericParamKind::Type { .. } | fhir::GenericParamKind::Lifetime => None, } } @@ -68,9 +66,7 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { let generics = self.map().get_generics(owner.expect_local()).unwrap(); let kind = generics.self_kind.as_ref()?; match kind { - fhir::GenericParamKind::BaseTy | fhir::GenericParamKind::SplTy => { - Some(rty::Sort::Param(rty::SELF_PARAM_TY)) - } + fhir::GenericParamKind::Base => Some(rty::Sort::Param(rty::SELF_PARAM_TY)), fhir::GenericParamKind::Type { .. } | fhir::GenericParamKind::Lifetime => None, } } @@ -82,7 +78,7 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> { } } - pub fn sort_of_ty(self, ty: &fhir::Ty) -> Option { + fn sort_of_ty(self, ty: &fhir::Ty) -> Option { match &ty.kind { fhir::TyKind::BaseTy(bty) | fhir::TyKind::Indexed(bty, _) => self.sort_of_bty(bty), fhir::TyKind::Exists(_, ty) | fhir::TyKind::Constr(_, ty) => self.sort_of_ty(ty), diff --git a/crates/flux-refineck/locales/en-US.ftl b/crates/flux-refineck/locales/en-US.ftl index 4909a54da7..35f93c1cc1 100644 --- a/crates/flux-refineck/locales/en-US.ftl +++ b/crates/flux-refineck/locales/en-US.ftl @@ -28,9 +28,6 @@ refineck_assert_error = refineck_param_inference_error = parameter inference error at function call -refineck_invalid_generic_arg = - cannot instantiate base or spl generic with opaque type - refineck_fold_error = type invariant may not hold (when place is folded) diff --git a/crates/flux-refineck/src/checker.rs b/crates/flux-refineck/src/checker.rs index be18d19565..2e68f23102 100644 --- a/crates/flux-refineck/src/checker.rs +++ b/crates/flux-refineck/src/checker.rs @@ -1147,11 +1147,13 @@ fn instantiate_args_for_fun_call( let refiner = Refiner::new(genv, caller_generics, |bty| { let sort = bty.sort(); - let mut ty = rty::Ty::indexed(bty.shift_in_escaping(1), rty::Expr::nu()); - if !sort.is_unit() { - ty = rty::Ty::constr(rty::Expr::hole(rty::HoleKind::Pred), ty); - } - rty::Binder::with_sort(ty, sort) + let bty = bty.shift_in_escaping(1); + let constr = if !sort.is_unit() { + rty::SubsetTy::new(bty, Expr::nu(), Expr::hole(rty::HoleKind::Pred)) + } else { + rty::SubsetTy::trivial(bty, Expr::nu()) + }; + Binder::with_sort(constr, sort) }); args.iter() @@ -1423,7 +1425,6 @@ pub(crate) mod errors { Inference, OpaqueStruct(DefId), Query(QueryErr), - InvalidGenericArg, } impl CheckerError { @@ -1445,12 +1446,6 @@ pub(crate) mod errors { flux_errors::diagnostic_id(), ) } - CheckerErrKind::InvalidGenericArg => { - handler.struct_err_with_code( - fluent::refineck_invalid_generic_arg, - flux_errors::diagnostic_id(), - ) - } CheckerErrKind::OpaqueStruct(def_id) => { let mut builder = handler.struct_err_with_code( fluent::refineck_opaque_struct_error, diff --git a/crates/flux-refineck/src/constraint_gen.rs b/crates/flux-refineck/src/constraint_gen.rs index 8263585924..4e62db02fd 100644 --- a/crates/flux-refineck/src/constraint_gen.rs +++ b/crates/flux-refineck/src/constraint_gen.rs @@ -9,9 +9,8 @@ use flux_middle::{ evars::{EVarCxId, EVarSol}, fold::TypeFoldable, AliasTy, BaseTy, BinOp, Binder, Constraint, CoroutineObligPredicate, ESpan, EVarGen, - EarlyBinder, Expr, ExprKind, FnOutput, GenericArg, GenericParamDefKind, HoleKind, - InferMode, Lambda, Mutability, Path, PolyFnSig, PolyVariant, PtrKind, Ref, Sort, Ty, - TyKind, Var, + EarlyBinder, Expr, ExprKind, FnOutput, GenericArg, HoleKind, InferMode, Lambda, Mutability, + Path, PolyFnSig, PolyVariant, PtrKind, Ref, Sort, Ty, TyKind, Var, }, rustc::mir::{BasicBlock, Place}, }; @@ -145,22 +144,6 @@ impl<'a, 'genv, 'tcx> ConstrGen<'a, 'genv, 'tcx> { .collect() } - fn check_generic_args(&self, did: DefId, generic_args: &[GenericArg]) -> Result { - let generics = self.genv.generics_of(did)?; - for (idx, arg) in generic_args.iter().enumerate() { - let param = generics.param_at(idx, self.genv)?; - match param.kind { - GenericParamDefKind::BaseTy => { - if !arg.is_valid_base_arg() { - return Err(CheckerErrKind::InvalidGenericArg); - } - } - _ => continue, - } - } - Ok(()) - } - #[allow(clippy::too_many_arguments)] pub(crate) fn check_fn_call( &mut self, @@ -174,10 +157,6 @@ impl<'a, 'genv, 'tcx> ConstrGen<'a, 'genv, 'tcx> { let genv = self.genv; let span = self.span; - if let Some(did) = callee_def_id { - self.check_generic_args(did, generic_args)?; - } - let mut infcx = self.infcx(rcx, ConstrReason::Call); let snapshot = rcx.snapshot(); @@ -681,23 +660,23 @@ impl<'a, 'genv, 'tcx> InferCtxt<'a, 'genv, 'tcx> { arg1: &GenericArg, arg2: &GenericArg, ) -> Result { - match (arg1, arg2) { - (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => { - match variance { - Variance::Covariant => self.subtyping(rcx, ty1, ty2), - Variance::Invariant => { - self.subtyping(rcx, ty1, ty2)?; - self.subtyping(rcx, ty2, ty1) - } - Variance::Contravariant => self.subtyping(rcx, ty2, ty1), - Variance::Bivariant => Ok(()), - } - } - (GenericArg::BaseTy(_), GenericArg::BaseTy(_)) => { - tracked_span_bug!("generic argument subtyping for base types is not implemented"); + let (ty1, ty2) = match (arg1, arg2) { + (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => (ty1.clone(), ty2.clone()), + (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => { + debug_assert_eq!(ctor1.sort(), ctor2.sort()); + (ctor1.to_ty(), ctor2.to_ty()) } - (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => Ok(()), + (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()), _ => tracked_span_bug!("incompatible generic args: `{arg1:?}` `{arg2:?}"), + }; + match variance { + Variance::Covariant => self.subtyping(rcx, &ty1, &ty2), + Variance::Invariant => { + self.subtyping(rcx, &ty1, &ty2)?; + self.subtyping(rcx, &ty2, &ty1) + } + Variance::Contravariant => self.subtyping(rcx, &ty2, &ty1), + Variance::Bivariant => Ok(()), } } diff --git a/crates/flux-refineck/src/fixpoint_encoding.rs b/crates/flux-refineck/src/fixpoint_encoding.rs index 8d4430af1a..1ec8e75a58 100644 --- a/crates/flux-refineck/src/fixpoint_encoding.rs +++ b/crates/flux-refineck/src/fixpoint_encoding.rs @@ -1093,7 +1093,7 @@ impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> { /// This function returns a very polymorphic sort for the UIF encoding the alias_pred; /// This is ok, as well-formedness in previous phases will ensure the function is always /// instantiated with the same sorts. However, the proper thing is to compute the *actual* -/// mono-sort at which this alias_pred is being used see [`GlobalEnv::sort_of_alias_pred`] but +/// mono-sort at which this alias_pred is being used see [`GlobalEnv::sort_of_alias_reft`] but /// that is a bit tedious as its done using the `fhir` (not `rty`). Alternatively, we might /// stash the computed mono-sort *in* the `rty::AliasPred` during `conv`? fn alias_reft_sort(arity: usize) -> rty::PolyFuncSort { diff --git a/crates/flux-refineck/src/refine_tree.rs b/crates/flux-refineck/src/refine_tree.rs index 8ef6df2ca2..f5a5516c5f 100644 --- a/crates/flux-refineck/src/refine_tree.rs +++ b/crates/flux-refineck/src/refine_tree.rs @@ -193,16 +193,17 @@ impl<'rcx> RefineCtxt<'rcx> { fresh } - /// Given a [`sort`] that may contain aggregate sorts ([tuples] or [records]), it destructs the sort - /// recursively, generating multiple fresh variables and returning the "eta-expanded" tuple of fresh - /// variables. This is in contrast to generating a single fresh variable of tuple sort. + /// Given a [`sort`] that may contain aggregate sorts ([tuple] or [adt]), it destructs the sort + /// recursively, generating multiple fresh variables and returning an "eta-expanded" expression + /// of fresh variables. This is in contrast to generating a single fresh variable of aggregate + /// sort. /// /// For example, given the sort `(int, (bool, int))` it returns `(a0, (a1, a2))` for fresh variables /// `a0: int`, `a1: bool`, and `a2: int`. /// /// [`sort`]: Sort - /// [tuples]: Sort::Tuple - /// [records]: Sort::Adt + /// [tuple]: Sort::Tuple + /// [adt]: flux_middle::rty::SortCtor::Adt pub(crate) fn define_vars(&mut self, sort: &Sort) -> Expr { Expr::fold_sort(sort, |sort| Expr::fvar(self.define_var(sort))) } @@ -321,7 +322,7 @@ impl<'a, 'rcx> Unpacker<'a, 'rcx> { impl TypeFolder for Unpacker<'_, '_> { fn fold_ty(&mut self, ty: &Ty) -> Ty { match ty.kind() { - TyKind::Indexed(bty, idxs) => Ty::indexed(bty.fold_with(self), idxs.clone()), + TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()), TyKind::Exists(bound_ty) if self.unpack_exists => { // HACK(nilehmann) In general we shouldn't unpack through mutable references because // that makes referent type too specific. We only have this as a workaround to infer diff --git a/crates/flux-refineck/src/type_env.rs b/crates/flux-refineck/src/type_env.rs index ee2512f565..db3eb41c46 100644 --- a/crates/flux-refineck/src/type_env.rs +++ b/crates/flux-refineck/src/type_env.rs @@ -7,12 +7,12 @@ use flux_middle::{ global_env::GlobalEnv, intern::List, rty::{ - self, box_args, + canonicalize::ShallowHoister, evars::EVarSol, - fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor}, + fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor}, subst::RegionSubst, - BaseTy, Binder, BoundVariableKind, Expr, ExprKind, GenericArg, HoleKind, Mutability, Path, - PtrKind, Region, SortCtor, Ty, TyKind, INNERMOST, + BaseTy, Binder, BoundReftKind, Expr, ExprKind, GenericArg, HoleKind, Mutability, Path, + PtrKind, Region, SortCtor, SubsetTy, Ty, TyKind, INNERMOST, }, rustc::mir::{BasicBlock, Local, LocalDecls, Place, PlaceElem}, }; @@ -373,8 +373,9 @@ impl BasicBlockEnvShape { fn pack_generic_arg(scope: &Scope, arg: &GenericArg) -> GenericArg { match arg { GenericArg::Ty(ty) => GenericArg::Ty(Self::pack_ty(scope, ty)), - GenericArg::BaseTy(arg) => { - GenericArg::BaseTy(arg.as_ref().map(|ty| Self::pack_ty(scope, ty))) + GenericArg::Base(arg) => { + assert!(!scope.has_free_vars(arg)); + GenericArg::Base(arg.clone()) } GenericArg::Lifetime(re) => GenericArg::Lifetime(*re), GenericArg::Const(c) => GenericArg::Const(c.clone()), @@ -489,11 +490,7 @@ impl BasicBlockEnvShape { e1.clone() } else { bound_sorts.push(sort.clone()); - Expr::late_bvar( - INNERMOST, - (bound_sorts.len() - 1) as u32, - rty::BoundReftKind::Annon, - ) + Expr::late_bvar(INNERMOST, (bound_sorts.len() - 1) as u32, BoundReftKind::Annon) } } } @@ -533,8 +530,20 @@ impl BasicBlockEnvShape { fn join_generic_arg(&self, arg1: &GenericArg, arg2: &GenericArg) -> GenericArg { match (arg1, arg2) { (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => GenericArg::Ty(self.join_ty(ty1, ty2)), - (GenericArg::BaseTy(_), GenericArg::BaseTy(_)) => { - tracked_span_bug!("generic argument join for base types is not implemented") + (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => { + let sty1 = ctor1.as_ref().skip_binder(); + let sty2 = ctor2.as_ref().skip_binder(); + debug_assert_eq3!(&sty1.idx, &sty2.idx, &Expr::nu()); + + let bty = self.join_bty(&sty1.bty, &sty2.bty); + let pred = if self.scope.has_free_vars(&sty2.pred) || sty1.pred != sty2.pred { + Expr::hole(HoleKind::Pred) + } else { + sty1.pred.clone() + }; + let sort = bty.sort(); + let ctor = Binder::with_sort(SubsetTy::new(bty, Expr::nu(), pred), sort); + GenericArg::Base(ctor) } (GenericArg::Lifetime(re1), GenericArg::Lifetime(re2)) => { debug_assert_eq!(re1, re2); @@ -547,9 +556,9 @@ impl BasicBlockEnvShape { pub fn into_bb_env(self, kvar_store: &mut KVarStore) -> BasicBlockEnv { let mut bindings = self.bindings; - let mut generalizer = Generalizer::new(); - bindings.fmap_mut(|ty| generalizer.generalize(ty)); - let (vars, preds) = generalizer.into_parts(); + let mut hoister = ShallowHoister::default(); + bindings.fmap_mut(|ty| hoister.hoist(ty)); + let (vars, preds) = hoister.into_parts(); // Replace all holes with a single fresh kvar on all parameters let mut constrs = preds @@ -578,63 +587,6 @@ impl BasicBlockEnvShape { } } -struct Generalizer { - vars: Vec, - preds: Vec, -} - -impl Generalizer { - fn new() -> Self { - Self { vars: vec![], preds: vec![] } - } - - fn into_parts(self) -> (List, Vec) { - (List::from_vec(self.vars), self.preds) - } - - fn generalize(&mut self, ty: &Ty) -> Ty { - ty.fold_with(self) - } -} - -impl TypeFolder for Generalizer { - fn fold_ty(&mut self, ty: &Ty) -> Ty { - match ty.kind() { - TyKind::Exists(ty) => { - ty.replace_bound_refts_with(|sort, mode, kind| { - let idx = self.vars.len(); - self.vars - .push(BoundVariableKind::Refine(sort.clone(), mode, kind)); - Expr::late_bvar(INNERMOST, idx as u32, kind) - }) - .fold_with(self) - } - TyKind::Constr(pred, ty) => { - self.preds.push(pred.clone()); - ty.fold_with(self) - } - _ => ty.clone(), - } - } - - fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy { - match bty { - BaseTy::Adt(adt_def, args) if adt_def.is_box() => { - let (boxed, alloc) = box_args(args); - let args = List::from_arr([ - GenericArg::Ty(boxed.fold_with(self)), - GenericArg::Ty(alloc.clone()), - ]); - BaseTy::Adt(adt_def.clone(), args) - } - BaseTy::Ref(re, ty, Mutability::Not) => { - BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not) - } - _ => bty.clone(), - } - } -} - impl TypeVisitable for BasicBlockEnvData { fn visit_with(&self, _visitor: &mut V) -> ControlFlow { unimplemented!() diff --git a/crates/flux-refineck/src/type_env/place_ty.rs b/crates/flux-refineck/src/type_env/place_ty.rs index b1319da208..51f3870f0d 100644 --- a/crates/flux-refineck/src/type_env/place_ty.rs +++ b/crates/flux-refineck/src/type_env/place_ty.rs @@ -847,7 +847,7 @@ fn fold( .map(|ty| fold(bindings, rcx, gen, ty, is_strg)) .try_collect_vec()?; - let partially_moved = fields.iter().any(|ty| ty.is_uninit()); + let partially_moved = fields.iter().any(Ty::is_uninit); let ty = if partially_moved { Ty::uninit() } else { @@ -868,7 +868,7 @@ fn fold( .map(|ty| fold(bindings, rcx, gen, ty, is_strg)) .try_collect_vec()?; - let partially_moved = fields.iter().any(|ty| ty.is_uninit()); + let partially_moved = fields.iter().any(Ty::is_uninit); let ty = if partially_moved { Ty::uninit() } else { Ty::tuple(fields) }; Ok(ty) } diff --git a/crates/flux-syntax/src/grammar.lalrpop b/crates/flux-syntax/src/grammar.lalrpop index 5313494f6b..750434b1fc 100644 --- a/crates/flux-syntax/src/grammar.lalrpop +++ b/crates/flux-syntax/src/grammar.lalrpop @@ -41,7 +41,6 @@ GenericParam: surface::GenericParam = { let kind = match kind.as_str() { "type" => surface::GenericParamKind::Type, "base" => surface::GenericParamKind::Base, - "spl" => surface::GenericParamKind::Spl, _ => return Err(ParseError::User { error: UserParseError::UnexpectedToken(lo, hi) }) }; Ok(surface::GenericParam { name, kind, node_id: cx.next_node_id() }) diff --git a/crates/flux-syntax/src/surface.rs b/crates/flux-syntax/src/surface.rs index a9594f4d2a..c5676bed14 100644 --- a/crates/flux-syntax/src/surface.rs +++ b/crates/flux-syntax/src/surface.rs @@ -72,7 +72,6 @@ pub struct GenericParam { #[derive(Debug)] pub enum GenericParamKind { Type, - Spl, Base, Refine { sort: Sort }, } diff --git a/crates/flux-syntax/src/surface/visit.rs b/crates/flux-syntax/src/surface/visit.rs index 6ec98e0b84..9065a5975d 100644 --- a/crates/flux-syntax/src/surface/visit.rs +++ b/crates/flux-syntax/src/surface/visit.rs @@ -203,7 +203,7 @@ pub fn walk_generic_param(vis: &mut V, param: &GenericParam) { vis.visit_ident(param.name); match ¶m.kind { GenericParamKind::Refine { sort } => vis.visit_sort(sort), - GenericParamKind::Type | GenericParamKind::Spl | GenericParamKind::Base => {} + GenericParamKind::Type | GenericParamKind::Base => {} } } diff --git a/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs b/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs index 012a6e4b69..4cafecd3de 100644 --- a/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs +++ b/crates/flux-tests/tests/neg/error_messages/wf/kinds00.rs @@ -13,3 +13,6 @@ pub fn test00_bad() -> RSet { //~^ ERROR values of this type cannot be used as base sorted instances RSet::::new() } + +#[flux::sig(fn(soup:RSet))] //~ ERROR values of this type cannot be used as base sorted instances +pub fn test01(_s: RSet) {} diff --git a/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs b/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs index 5e01e1ce2a..e621454daa 100644 --- a/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs +++ b/crates/flux-tests/tests/neg/error_messages/wf/kinds01.rs @@ -5,11 +5,12 @@ use std::hash::Hash; use rset::RSet; -#[flux::sig(fn(soup:RSet))] //~ ERROR values of this type cannot be used as base sorted instances -pub fn test04(_s: RSet) {} - +// This error is confusing. The problem is that `RSet` cannot be instantiated with `T` because it is +// of kind `type`. But we don't know that until after we convert the argument to `rty` and we check if +// it is a valid simple type. The reason we only know that after conv is that we need to expand type +// aliases. #[flux::sig(fn(RSet[@salt]))] //~ ERROR type cannot be refined -pub fn test05(_s: RSet) +pub fn test01(_s: RSet) where T: Eq + Hash, {