[PR testsuite/116860] Testsuite adjustment for recently added tests
[official-gcc.git] / gcc / rust / typecheck / rust-autoderef.cc
blob c302bb9b3fb68c1ce6297e8fde8e33d0f18b6b40
1 // Copyright (C) 2020-2025 Free Software Foundation, Inc.
3 // This file is part of GCC.
5 // GCC is free software; you can redistribute it and/or modify it under
6 // the terms of the GNU General Public License as published by the Free
7 // Software Foundation; either version 3, or (at your option) any later
8 // version.
10 // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 // for more details.
15 // You should have received a copy of the GNU General Public License
16 // along with GCC; see the file COPYING3. If not see
17 // <http://www.gnu.org/licenses/>.
19 #include "rust-autoderef.h"
20 #include "rust-hir-path-probe.h"
21 #include "rust-hir-dot-operator.h"
22 #include "rust-hir-trait-resolve.h"
23 #include "rust-type-util.h"
24 #include "rust-substitution-mapper.h"
26 namespace Rust {
27 namespace Resolver {
29 static bool
30 resolve_operator_overload_fn (
31 LangItem::Kind lang_item_type, TyTy::BaseType *ty, TyTy::FnType **resolved_fn,
32 Adjustment::AdjustmentType *requires_ref_adjustment);
34 TyTy::BaseType *
35 Adjuster::adjust_type (const std::vector<Adjustment> &adjustments)
37 if (adjustments.size () == 0)
38 return base->clone ();
40 return adjustments.back ().get_expected ()->clone ();
43 Adjustment
44 Adjuster::try_deref_type (TyTy::BaseType *ty, LangItem::Kind deref_lang_item)
46 TyTy::FnType *fn = nullptr;
47 Adjustment::AdjustmentType requires_ref_adjustment
48 = Adjustment::AdjustmentType::ERROR;
49 bool operator_overloaded
50 = resolve_operator_overload_fn (deref_lang_item, ty, &fn,
51 &requires_ref_adjustment);
52 if (!operator_overloaded)
54 return Adjustment::get_error ();
57 auto resolved_base = fn->get_return_type ()->destructure ();
58 bool is_valid_type = resolved_base->get_kind () == TyTy::TypeKind::REF;
59 if (!is_valid_type)
60 return Adjustment::get_error ();
62 TyTy::ReferenceType *ref_base
63 = static_cast<TyTy::ReferenceType *> (resolved_base);
65 Adjustment::AdjustmentType adjustment_type
66 = Adjustment::AdjustmentType::ERROR;
67 switch (deref_lang_item)
69 case LangItem::Kind::DEREF:
70 adjustment_type = Adjustment::AdjustmentType::DEREF;
71 break;
73 case LangItem::Kind::DEREF_MUT:
74 adjustment_type = Adjustment::AdjustmentType::DEREF_MUT;
75 break;
77 default:
78 break;
81 return Adjustment::get_op_overload_deref_adjustment (adjustment_type, ty,
82 ref_base, fn,
83 requires_ref_adjustment);
86 Adjustment
87 Adjuster::try_raw_deref_type (TyTy::BaseType *ty)
89 bool is_valid_type = ty->get_kind () == TyTy::TypeKind::REF;
90 if (!is_valid_type)
91 return Adjustment::get_error ();
93 const TyTy::ReferenceType *ref_base
94 = static_cast<const TyTy::ReferenceType *> (ty);
95 auto infered = ref_base->get_base ()->destructure ();
97 return Adjustment (Adjustment::AdjustmentType::INDIRECTION, ty, infered);
100 Adjustment
101 Adjuster::try_unsize_type (TyTy::BaseType *ty)
103 bool is_valid_type = ty->get_kind () == TyTy::TypeKind::ARRAY;
104 if (!is_valid_type)
105 return Adjustment::get_error ();
107 auto mappings = Analysis::Mappings::get ();
108 auto context = TypeCheckContext::get ();
110 const auto ref_base = static_cast<const TyTy::ArrayType *> (ty);
111 auto slice_elem = ref_base->get_element_type ();
113 auto slice
114 = new TyTy::SliceType (mappings->get_next_hir_id (), ty->get_ident ().locus,
115 TyTy::TyVar (slice_elem->get_ref ()));
116 context->insert_implicit_type (slice);
118 return Adjustment (Adjustment::AdjustmentType::UNSIZE, ty, slice);
121 static bool
122 resolve_operator_overload_fn (
123 LangItem::Kind lang_item_type, TyTy::BaseType *lhs,
124 TyTy::FnType **resolved_fn,
125 Adjustment::AdjustmentType *requires_ref_adjustment)
127 auto context = TypeCheckContext::get ();
128 auto mappings = Analysis::Mappings::get ();
130 // look up lang item for arithmetic type
131 std::string associated_item_name = LangItem::ToString (lang_item_type);
132 DefId respective_lang_item_id = UNKNOWN_DEFID;
133 bool lang_item_defined
134 = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id);
136 if (!lang_item_defined)
137 return false;
139 // we might be in a static or const context and unknown is fine
140 TypeCheckContextItem current_context = TypeCheckContextItem::get_error ();
141 if (context->have_function_context ())
143 current_context = context->peek_context ();
146 // this flags stops recurisve calls to try and deref when none is available
147 // which will cause an infinite loop
148 bool autoderef_flag = true;
149 auto segment = HIR::PathIdentSegment (associated_item_name);
150 auto candidates = MethodResolver::Probe (lhs, segment, autoderef_flag);
152 // remove any recursive candidates
153 std::set<MethodCandidate> resolved_candidates;
154 for (auto &c : candidates)
156 const TyTy::BaseType *candidate_type = c.candidate.ty;
157 rust_assert (candidate_type->get_kind () == TyTy::TypeKind::FNDEF);
159 const TyTy::FnType &fn
160 = *static_cast<const TyTy::FnType *> (candidate_type);
162 DefId current_fn_defid = current_context.get_defid ();
163 bool recursive_candidated = fn.get_id () == current_fn_defid;
164 if (!recursive_candidated)
166 resolved_candidates.insert (c);
170 auto selected_candidates
171 = MethodResolver::Select (resolved_candidates, lhs, {});
172 bool have_implementation_for_lang_item = selected_candidates.size () > 0;
173 if (!have_implementation_for_lang_item)
174 return false;
176 if (selected_candidates.size () > 1)
178 // no need to error out as we are just trying to see if there is a fit
179 return false;
182 // Get the adjusted self
183 MethodCandidate candidate = *selected_candidates.begin ();
184 Adjuster adj (lhs);
185 TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
187 PathProbeCandidate &resolved_candidate = candidate.candidate;
188 TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
189 rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
190 TyTy::BaseType *lookup = lookup_tyty;
191 TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
192 rust_assert (fn->is_method ());
194 rust_debug ("is_impl_item_candidate: %s",
195 resolved_candidate.is_impl_candidate () ? "true" : "false");
197 // in the case where we resolve to a trait bound we have to be careful we are
198 // able to do so there is a case where we are currently resolving the deref
199 // operator overload function which is generic and this might resolve to the
200 // trait item of deref which is not valid as its just another recursive case
201 if (current_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
203 auto &impl_item = current_context.get_impl_item ();
204 HIR::ImplBlock *parent = impl_item.first;
205 HIR::Function *fn = impl_item.second;
207 if (parent->has_trait_ref ()
208 && fn->get_function_name ().as_string ().compare (
209 associated_item_name)
210 == 0)
212 TraitReference *trait_reference
213 = TraitResolver::Lookup (*parent->get_trait_ref ().get ());
214 if (!trait_reference->is_error ())
216 TyTy::BaseType *lookup = nullptr;
217 bool ok = context->lookup_type (fn->get_mappings ().get_hirid (),
218 &lookup);
219 rust_assert (ok);
220 rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);
222 TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup);
223 rust_assert (fntype->is_method ());
225 bool is_lang_item_impl
226 = trait_reference->get_mappings ().get_defid ()
227 == respective_lang_item_id;
228 bool self_is_lang_item_self
229 = fntype->get_self_type ()->is_equal (*adjusted_self);
230 bool recursive_operator_overload
231 = is_lang_item_impl && self_is_lang_item_self;
233 if (recursive_operator_overload)
234 return false;
239 // we found a valid operator overload
240 fn->prepare_higher_ranked_bounds ();
241 rust_debug ("resolved operator overload to: {%u} {%s}",
242 candidate.candidate.ty->get_ref (),
243 candidate.candidate.ty->debug_str ().c_str ());
245 if (fn->needs_substitution ())
247 if (lhs->get_kind () == TyTy::TypeKind::ADT)
249 const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (lhs);
251 auto s = fn->get_self_type ()->get_root ();
252 rust_assert (s->can_eq (adt, false));
253 rust_assert (s->get_kind () == TyTy::TypeKind::ADT);
254 const TyTy::ADTType *self_adt
255 = static_cast<const TyTy::ADTType *> (s);
257 // we need to grab the Self substitutions as the inherit type
258 // parameters for this
259 if (self_adt->needs_substitution ())
261 rust_assert (adt->was_substituted ());
263 TyTy::SubstitutionArgumentMappings used_args_in_prev_segment
264 = GetUsedSubstArgs::From (adt);
266 TyTy::SubstitutionArgumentMappings inherit_type_args
267 = self_adt->solve_mappings_from_receiver_for_self (
268 used_args_in_prev_segment);
270 // there may or may not be inherited type arguments
271 if (!inherit_type_args.is_error ())
273 // need to apply the inherited type arguments to the
274 // function
275 lookup = fn->handle_substitions (inherit_type_args);
279 else
281 rust_assert (candidate.adjustments.size () < 2);
283 // lets infer the params for this we could probably fix this up by
284 // actually just performing a substitution of a single param but this
285 // seems more generic i think.
287 // this is the case where we had say Foo<&Bar>> and we have derefed to
288 // the &Bar and we are trying to match a method self of Bar which
289 // requires another deref which is matched to the deref trait impl of
290 // &&T so this requires another reference and deref call
292 lookup = fn->infer_substitions (UNDEF_LOCATION);
293 rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);
294 fn = static_cast<TyTy::FnType *> (lookup);
296 location_t unify_locus = mappings->lookup_location (lhs->get_ref ());
297 unify_site (lhs->get_ref (),
298 TyTy::TyWithLocation (fn->get_self_type ()),
299 TyTy::TyWithLocation (adjusted_self), unify_locus);
301 lookup = fn;
305 if (candidate.adjustments.size () > 0)
306 *requires_ref_adjustment = candidate.adjustments.at (0).get_type ();
308 *resolved_fn = static_cast<TyTy::FnType *> (lookup);
310 return true;
313 AutoderefCycle::AutoderefCycle (bool autoderef_flag)
314 : autoderef_flag (autoderef_flag)
317 AutoderefCycle::~AutoderefCycle () {}
319 void
320 AutoderefCycle::try_hook (const TyTy::BaseType &)
323 bool
324 AutoderefCycle::cycle (TyTy::BaseType *receiver)
326 TyTy::BaseType *r = receiver;
327 while (true)
329 rust_debug ("autoderef try 1: {%s}", r->debug_str ().c_str ());
330 if (try_autoderefed (r))
331 return true;
333 // 4. deref to to 1, if cannot deref then quit
334 if (autoderef_flag)
335 return false;
337 // try unsize
338 Adjustment unsize = Adjuster::try_unsize_type (r);
339 if (!unsize.is_error ())
341 adjustments.push_back (unsize);
342 auto unsize_r = unsize.get_expected ();
344 rust_debug ("autoderef try unsize: {%s}",
345 unsize_r->debug_str ().c_str ());
346 if (try_autoderefed (unsize_r))
347 return true;
349 adjustments.pop_back ();
352 bool is_ptr = receiver->get_kind () == TyTy::TypeKind::POINTER;
353 if (is_ptr)
355 // deref of raw pointers is unsafe
356 return false;
359 Adjustment deref = Adjuster::try_deref_type (r, LangItem::Kind::DEREF);
360 if (!deref.is_error ())
362 auto deref_r = deref.get_expected ();
363 adjustments.push_back (deref);
365 rust_debug ("autoderef try lang-item DEREF: {%s}",
366 deref_r->debug_str ().c_str ());
367 if (try_autoderefed (deref_r))
368 return true;
370 adjustments.pop_back ();
373 Adjustment deref_mut
374 = Adjuster::try_deref_type (r, LangItem::Kind::DEREF_MUT);
375 if (!deref_mut.is_error ())
377 auto deref_r = deref_mut.get_expected ();
378 adjustments.push_back (deref_mut);
380 rust_debug ("autoderef try lang-item DEREF_MUT: {%s}",
381 deref_r->debug_str ().c_str ());
382 if (try_autoderefed (deref_r))
383 return true;
385 adjustments.pop_back ();
388 if (!deref_mut.is_error ())
390 auto deref_r = deref_mut.get_expected ();
391 adjustments.push_back (deref_mut);
392 Adjustment raw_deref = Adjuster::try_raw_deref_type (deref_r);
393 adjustments.push_back (raw_deref);
394 deref_r = raw_deref.get_expected ();
396 if (try_autoderefed (deref_r))
397 return true;
399 adjustments.pop_back ();
400 adjustments.pop_back ();
403 if (!deref.is_error ())
405 r = deref.get_expected ();
406 adjustments.push_back (deref);
408 Adjustment raw_deref = Adjuster::try_raw_deref_type (r);
409 if (raw_deref.is_error ())
410 return false;
412 r = raw_deref.get_expected ();
413 adjustments.push_back (raw_deref);
415 return false;
418 bool
419 AutoderefCycle::try_autoderefed (TyTy::BaseType *r)
421 try_hook (*r);
423 // 1. try raw
424 if (select (*r))
425 return true;
427 // 2. try ref
428 TyTy::ReferenceType *r1
429 = new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
430 Mutability::Imm);
431 adjustments.push_back (
432 Adjustment (Adjustment::AdjustmentType::IMM_REF, r, r1));
433 if (select (*r1))
434 return true;
436 adjustments.pop_back ();
438 // 3. try mut ref
439 TyTy::ReferenceType *r2
440 = new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
441 Mutability::Mut);
442 adjustments.push_back (
443 Adjustment (Adjustment::AdjustmentType::MUT_REF, r, r2));
444 if (select (*r2))
445 return true;
447 adjustments.pop_back ();
449 return false;
452 } // namespace Resolver
453 } // namespace Rust