allow out-of-class defaulting of comparison operators
authorrobert <robert@openbsd.org>
Sun, 26 Mar 2023 09:09:31 +0000 (09:09 +0000)
committerrobert <robert@openbsd.org>
Sun, 26 Mar 2023 09:09:31 +0000 (09:09 +0000)
ok deraadt@, mortimer@

this is backport of the following upstream commit:

commit 5fbe21a7748f91adbd1b16c95bbfe180642320a3
Author: Nathan Sidwell <nathan@acm.org>
Date:   Wed Jun 16 10:43:43 2021 -0700

    [clang] p2085 out-of-class comparison operator defaulting

    This implements p2085, allowing out-of-class defaulting of comparison
    operators, primarily so they need not be inline, IIUC intent. this was
    mostly straigh forward, but required reimplementing
    Sema::CheckExplicitlyDefaultedComparison, as now there's a case where
    we have no a priori clue as to what class a defaulted comparison may
    be for. We have to inspect the parameter types to find out. Eg:

    class X { ... };
    bool operator==(X, X) = default;

    Thus reimplemented the parameter type checking, and added 'is this a
    friend' functionality for the above case.

gnu/llvm/clang/include/clang/Basic/DiagnosticSemaKinds.td
gnu/llvm/clang/lib/Sema/SemaDeclCXX.cpp

index 845c922..f7ed458 100644 (file)
@@ -9100,15 +9100,22 @@ def warn_cxx17_compat_defaulted_comparison : Warning<
   "before C++20">, InGroup<CXXPre20Compat>, DefaultIgnore;
 def err_defaulted_comparison_template : Error<
   "comparison operator template cannot be defaulted">;
-def err_defaulted_comparison_out_of_class : Error<
-  "%sub{select_defaulted_comparison_kind}0 can only be defaulted in a class "
-  "definition">;
+def err_defaulted_comparison_num_args : Error<
+  "%select{non-member|member}0 %sub{select_defaulted_comparison_kind}1"
+  " comparison operator must have %select{2|1}0 parameters">;
 def err_defaulted_comparison_param : Error<
   "invalid parameter type for defaulted %sub{select_defaulted_comparison_kind}0"
   "; found %1, expected %2%select{| or %4}3">;
+def err_defaulted_comparison_param_unknown : Error<
+  "invalid parameter type for non-member defaulted"
+  " %sub{select_defaulted_comparison_kind}0"
+  "; found %1, expected class or reference to a constant class">;
 def err_defaulted_comparison_param_mismatch : Error<
   "parameters for defaulted %sub{select_defaulted_comparison_kind}0 "
   "must have the same type%diff{ (found $ vs $)|}1,2">;
+def err_defaulted_comparison_not_friend : Error<
+  "%sub{select_defaulted_comparison_kind}0 is not a friend of"
+  " %select{|incomplete class }1%2">;
 def err_defaulted_comparison_non_const : Error<
   "defaulted member %sub{select_defaulted_comparison_kind}0 must be "
   "const-qualified">;
index ac01beb..23020a9 100644 (file)
@@ -8371,9 +8371,6 @@ bool Sema::CheckExplicitlyDefaultedComparison(Scope *S, FunctionDecl *FD,
                                               DefaultedComparisonKind DCK) {
   assert(DCK != DefaultedComparisonKind::None && "not a defaulted comparison");
 
-  CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(FD->getLexicalDeclContext());
-  assert(RD && "defaulted comparison is not defaulted in a class");
-
   // Perform any unqualified lookups we're going to need to default this
   // function.
   if (S) {
@@ -8391,43 +8388,17 @@ bool Sema::CheckExplicitlyDefaultedComparison(Scope *S, FunctionDecl *FD,
   //       const C&, or
   //    -- a friend of C having two parameters of type const C& or two
   //       parameters of type C.
-  QualType ExpectedParmType1 = Context.getRecordType(RD);
-  QualType ExpectedParmType2 =
-      Context.getLValueReferenceType(ExpectedParmType1.withConst());
-  if (isa<CXXMethodDecl>(FD))
-    ExpectedParmType1 = ExpectedParmType2;
-  for (const ParmVarDecl *Param : FD->parameters()) {
-    if (!Param->getType()->isDependentType() &&
-        !Context.hasSameType(Param->getType(), ExpectedParmType1) &&
-        !Context.hasSameType(Param->getType(), ExpectedParmType2)) {
-      // Don't diagnose an implicit 'operator=='; we will have diagnosed the
-      // corresponding defaulted 'operator<=>' already.
-      if (!FD->isImplicit()) {
-        Diag(FD->getLocation(), diag::err_defaulted_comparison_param)
-            << (int)DCK << Param->getType() << ExpectedParmType1
-            << !isa<CXXMethodDecl>(FD)
-            << ExpectedParmType2 << Param->getSourceRange();
-      }
-      return true;
-    }
-  }
-  if (FD->getNumParams() == 2 &&
-      !Context.hasSameType(FD->getParamDecl(0)->getType(),
-                           FD->getParamDecl(1)->getType())) {
-    if (!FD->isImplicit()) {
-      Diag(FD->getLocation(), diag::err_defaulted_comparison_param_mismatch)
-          << (int)DCK
-          << FD->getParamDecl(0)->getType()
-          << FD->getParamDecl(0)->getSourceRange()
-          << FD->getParamDecl(1)->getType()
-          << FD->getParamDecl(1)->getSourceRange();
-    }
-    return true;
-  }
 
-  // ... non-static const member ...
-  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
+  CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(FD->getLexicalDeclContext());
+  bool IsMethod = isa<CXXMethodDecl>(FD);
+  if (IsMethod) {
+    auto *MD = cast<CXXMethodDecl>(FD);
     assert(!MD->isStatic() && "comparison function cannot be a static member");
+
+    // If we're out-of-class, this is the class we're comparing.
+    if (!RD)
+      RD = MD->getParent();
+
     if (!MD->isConst()) {
       SourceLocation InsertLoc;
       if (FunctionTypeLoc Loc = MD->getFunctionTypeLoc())
@@ -8436,7 +8407,7 @@ bool Sema::CheckExplicitlyDefaultedComparison(Scope *S, FunctionDecl *FD,
       // corresponding defaulted 'operator<=>' already.
       if (!MD->isImplicit()) {
         Diag(MD->getLocation(), diag::err_defaulted_comparison_non_const)
-          << (int)DCK << FixItHint::CreateInsertion(InsertLoc, " const");
+            << (int)DCK << FixItHint::CreateInsertion(InsertLoc, " const");
       }
 
       // Add the 'const' to the type to recover.
@@ -8446,9 +8417,100 @@ bool Sema::CheckExplicitlyDefaultedComparison(Scope *S, FunctionDecl *FD,
       MD->setType(Context.getFunctionType(FPT->getReturnType(),
                                           FPT->getParamTypes(), EPI));
     }
-  } else {
-    // A non-member function declared in a class must be a friend.
+  }
+
+  if (FD->getNumParams() != (IsMethod ? 1 : 2)) {
+    // Let's not worry about using a variadic template pack here -- who would do
+    // such a thing?
+    Diag(FD->getLocation(), diag::err_defaulted_comparison_num_args)
+        << int(IsMethod) << int(DCK);
+    return true;
+  }
+
+  const ParmVarDecl *KnownParm = nullptr;
+  for (const ParmVarDecl *Param : FD->parameters()) {
+    QualType ParmTy = Param->getType();
+    if (ParmTy->isDependentType())
+      continue;
+    if (!KnownParm) {
+      auto CTy = ParmTy;
+      // Is it `T const &`?
+      bool Ok = !IsMethod;
+      QualType ExpectedTy;
+      if (RD)
+        ExpectedTy = Context.getRecordType(RD);
+      if (auto *Ref = CTy->getAs<ReferenceType>()) {
+        CTy = Ref->getPointeeType();
+        if (RD)
+          ExpectedTy.addConst();
+        Ok = true;
+      }
+
+      // Is T a class?
+      if (!Ok) {
+      } else if (RD) {
+        if (!RD->isDependentType() && !Context.hasSameType(CTy, ExpectedTy))
+          Ok = false;
+      } else if (auto *CRD = CTy->getAsRecordDecl()) {
+        RD = cast<CXXRecordDecl>(CRD);
+      } else {
+        Ok = false;
+      }
+
+      if (Ok) {
+        KnownParm = Param;
+      } else {
+        // Don't diagnose an implicit 'operator=='; we will have diagnosed the
+        // corresponding defaulted 'operator<=>' already.
+        if (!FD->isImplicit()) {
+          if (RD) {
+            QualType PlainTy = Context.getRecordType(RD);
+            QualType RefTy =
+                Context.getLValueReferenceType(PlainTy.withConst());
+            if (IsMethod)
+              PlainTy = QualType();
+            Diag(FD->getLocation(), diag::err_defaulted_comparison_param)
+                << int(DCK) << ParmTy << RefTy << int(!IsMethod) << PlainTy
+                << Param->getSourceRange();
+          } else {
+            assert(!IsMethod && "should know expected type for method");
+            Diag(FD->getLocation(),
+                 diag::err_defaulted_comparison_param_unknown)
+                << int(DCK) << ParmTy << Param->getSourceRange();
+          }
+        }
+        return true;
+      }
+    } else if (!Context.hasSameType(KnownParm->getType(), ParmTy)) {
+      Diag(FD->getLocation(), diag::err_defaulted_comparison_param_mismatch)
+          << int(DCK) << KnownParm->getType() << KnownParm->getSourceRange()
+          << ParmTy << Param->getSourceRange();
+      return true;
+    }
+  }
+
+  assert(RD && "must have determined class");
+  if (IsMethod) {
+  } else if (isa<CXXRecordDecl>(FD->getLexicalDeclContext())) {
+    // In-class, must be a friend decl.
     assert(FD->getFriendObjectKind() && "expected a friend declaration");
+  } else {
+    // Out of class, require the defaulted comparison to be a friend (of a
+    // complete type).
+    if (RequireCompleteType(FD->getLocation(), Context.getRecordType(RD),
+                            diag::err_defaulted_comparison_not_friend, int(DCK),
+                            int(1)))
+      return true;
+
+    if (llvm::find_if(RD->friends(), [&](const FriendDecl *F) {
+          return FD->getCanonicalDecl() ==
+                 F->getFriendDecl()->getCanonicalDecl();
+        }) == RD->friends().end()) {
+      Diag(FD->getLocation(), diag::err_defaulted_comparison_not_friend)
+          << int(DCK) << int(0) << RD;
+      Diag(RD->getCanonicalDecl()->getLocation(), diag::note_declared_at);
+      return true;
+    }
   }
 
   // C++2a [class.eq]p1, [class.rel]p1:
@@ -8606,7 +8668,10 @@ void Sema::DefineDefaultedComparison(SourceLocation UseLoc, FunctionDecl *FD,
 
   {
     // Build and set up the function body.
-    CXXRecordDecl *RD = cast<CXXRecordDecl>(FD->getLexicalParent());
+    // The first parameter has type maybe-ref-to maybe-const T, use that to get
+    // the type of the class being compared.
+    auto PT = FD->getParamDecl(0)->getType();
+    CXXRecordDecl *RD = PT.getNonReferenceType()->getAsCXXRecordDecl();
     SourceLocation BodyLoc =
         FD->getEndLoc().isValid() ? FD->getEndLoc() : FD->getLocation();
     StmtResult Body =
@@ -17088,13 +17153,6 @@ void Sema::SetDeclDefaulted(Decl *Dcl, SourceLocation DefaultLoc) {
     return;
   }
 
-  if (DefKind.isComparison() &&
-      !isa<CXXRecordDecl>(FD->getLexicalDeclContext())) {
-    Diag(FD->getLocation(), diag::err_defaulted_comparison_out_of_class)
-        << (int)DefKind.asComparison();
-    return;
-  }
-
   // Issue compatibility warning. We already warned if the operator is
   // 'operator<=>' when parsing the '<=>' token.
   if (DefKind.isComparison() &&
@@ -17116,31 +17174,37 @@ void Sema::SetDeclDefaulted(Decl *Dcl, SourceLocation DefaultLoc) {
   // that we've marked it as defaulted.
   FD->setWillHaveBody(false);
 
-  // If this definition appears within the record, do the checking when
-  // the record is complete. This is always the case for a defaulted
-  // comparison.
-  if (DefKind.isComparison())
+  // If this is a comparison's defaulted definition within the record, do
+  // the checking when the record is complete.
+  if (DefKind.isComparison() && isa<CXXRecordDecl>(FD->getLexicalDeclContext()))
     return;
-  auto *MD = cast<CXXMethodDecl>(FD);
-
-  const FunctionDecl *Primary = FD;
-  if (const FunctionDecl *Pattern = FD->getTemplateInstantiationPattern())
-    // Ask the template instantiation pattern that actually had the
-    // '= default' on it.
-    Primary = Pattern;
 
-  // If the method was defaulted on its first declaration, we will have
+  // If this member fn was defaulted on its first declaration, we will have
   // already performed the checking in CheckCompletedCXXClass. Such a
   // declaration doesn't trigger an implicit definition.
-  if (Primary->getCanonicalDecl()->isDefaulted())
-    return;
+  if (isa<CXXMethodDecl>(FD)) {
+    const FunctionDecl *Primary = FD;
+    if (const FunctionDecl *Pattern = FD->getTemplateInstantiationPattern())
+      // Ask the template instantiation pattern that actually had the
+      // '= default' on it.
+      Primary = Pattern;
+    if (Primary->getCanonicalDecl()->isDefaulted())
+      return;
+  }
 
-  // FIXME: Once we support defining comparisons out of class, check for a
-  // defaulted comparison here.
-  if (CheckExplicitlyDefaultedSpecialMember(MD, DefKind.asSpecialMember()))
-    MD->setInvalidDecl();
-  else
-    DefineDefaultedFunction(*this, MD, DefaultLoc);
+  if (DefKind.isComparison()) {
+    if (CheckExplicitlyDefaultedComparison(nullptr, FD, DefKind.asComparison()))
+      FD->setInvalidDecl();
+    else
+      DefineDefaultedComparison(DefaultLoc, FD, DefKind.asComparison());
+  } else {
+    auto *MD = cast<CXXMethodDecl>(FD);
+
+    if (CheckExplicitlyDefaultedSpecialMember(MD, DefKind.asSpecialMember()))
+      MD->setInvalidDecl();
+    else
+      DefineDefaultedFunction(*this, MD, DefaultLoc);
+  }
 }
 
 static void SearchForReturnInStmt(Sema &Self, Stmt *S) {