1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25
26 using namespace clang;
27 using namespace CodeGen;
28
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)29 void CodeGenPGO::setFuncName(StringRef Name,
30 llvm::GlobalValue::LinkageTypes Linkage) {
31 StringRef RawFuncName = Name;
32
33 // Function names may be prefixed with a binary '1' to indicate
34 // that the backend should not modify the symbols due to any platform
35 // naming convention. Do not include that '1' in the PGO profile name.
36 if (RawFuncName[0] == '\1')
37 RawFuncName = RawFuncName.substr(1);
38
39 FuncName = RawFuncName;
40 if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41 // For local symbols, prepend the main file name to distinguish them.
42 // Do not include the full path in the file name since there's no guarantee
43 // that it will stay the same, e.g., if the files are checked out from
44 // version control in different locations.
45 if (CGM.getCodeGenOpts().MainFileName.empty())
46 FuncName = FuncName.insert(0, "<unknown>:");
47 else
48 FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49 }
50
51 // If we're generating a profile, create a variable for the name.
52 if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53 createFuncNameVar(Linkage);
54 }
55
setFuncName(llvm::Function * Fn)56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57 setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59
createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage)60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61 // We generally want to match the function's linkage, but available_externally
62 // and extern_weak both have the wrong semantics, and anything that doesn't
63 // need to link across compilation units doesn't need to be visible at all.
64 if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
65 Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
66 else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
67 Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
68 else if (Linkage == llvm::GlobalValue::InternalLinkage ||
69 Linkage == llvm::GlobalValue::ExternalLinkage)
70 Linkage = llvm::GlobalValue::PrivateLinkage;
71
72 auto *Value =
73 llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
74 FuncNameVar =
75 new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
76 Value, "__llvm_profile_name_" + FuncName);
77
78 // Hide the symbol so that we correctly get a copy for each executable.
79 if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
80 FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
81 }
82
83 namespace {
84 /// \brief Stable hasher for PGO region counters.
85 ///
86 /// PGOHash produces a stable hash of a given function's control flow.
87 ///
88 /// Changing the output of this hash will invalidate all previously generated
89 /// profiles -- i.e., don't do it.
90 ///
91 /// \note When this hash does eventually change (years?), we still need to
92 /// support old hashes. We'll need to pull in the version number from the
93 /// profile data format and use the matching hash function.
94 class PGOHash {
95 uint64_t Working;
96 unsigned Count;
97 llvm::MD5 MD5;
98
99 static const int NumBitsPerType = 6;
100 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
101 static const unsigned TooBig = 1u << NumBitsPerType;
102
103 public:
104 /// \brief Hash values for AST nodes.
105 ///
106 /// Distinct values for AST nodes that have region counters attached.
107 ///
108 /// These values must be stable. All new members must be added at the end,
109 /// and no members should be removed. Changing the enumeration value for an
110 /// AST node will affect the hash of every function that contains that node.
111 enum HashType : unsigned char {
112 None = 0,
113 LabelStmt = 1,
114 WhileStmt,
115 DoStmt,
116 ForStmt,
117 CXXForRangeStmt,
118 ObjCForCollectionStmt,
119 SwitchStmt,
120 CaseStmt,
121 DefaultStmt,
122 IfStmt,
123 CXXTryStmt,
124 CXXCatchStmt,
125 ConditionalOperator,
126 BinaryOperatorLAnd,
127 BinaryOperatorLOr,
128 BinaryConditionalOperator,
129
130 // Keep this last. It's for the static assert that follows.
131 LastHashType
132 };
133 static_assert(LastHashType <= TooBig, "Too many types in HashType");
134
135 // TODO: When this format changes, take in a version number here, and use the
136 // old hash calculation for file formats that used the old hash.
PGOHash()137 PGOHash() : Working(0), Count(0) {}
138 void combine(HashType Type);
139 uint64_t finalize();
140 };
141 const int PGOHash::NumBitsPerType;
142 const unsigned PGOHash::NumTypesPerWord;
143 const unsigned PGOHash::TooBig;
144
145 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
146 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
147 /// The next counter value to assign.
148 unsigned NextCounter;
149 /// The function hash.
150 PGOHash Hash;
151 /// The map of statements to counters.
152 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
153
MapRegionCounters__anonf92c78d20111::MapRegionCounters154 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
155 : NextCounter(0), CounterMap(CounterMap) {}
156
157 // Blocks and lambdas are handled as separate functions, so we need not
158 // traverse them in the parent context.
TraverseBlockExpr__anonf92c78d20111::MapRegionCounters159 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaBody__anonf92c78d20111::MapRegionCounters160 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
TraverseCapturedStmt__anonf92c78d20111::MapRegionCounters161 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
162
VisitDecl__anonf92c78d20111::MapRegionCounters163 bool VisitDecl(const Decl *D) {
164 switch (D->getKind()) {
165 default:
166 break;
167 case Decl::Function:
168 case Decl::CXXMethod:
169 case Decl::CXXConstructor:
170 case Decl::CXXDestructor:
171 case Decl::CXXConversion:
172 case Decl::ObjCMethod:
173 case Decl::Block:
174 case Decl::Captured:
175 CounterMap[D->getBody()] = NextCounter++;
176 break;
177 }
178 return true;
179 }
180
VisitStmt__anonf92c78d20111::MapRegionCounters181 bool VisitStmt(const Stmt *S) {
182 auto Type = getHashType(S);
183 if (Type == PGOHash::None)
184 return true;
185
186 CounterMap[S] = NextCounter++;
187 Hash.combine(Type);
188 return true;
189 }
getHashType__anonf92c78d20111::MapRegionCounters190 PGOHash::HashType getHashType(const Stmt *S) {
191 switch (S->getStmtClass()) {
192 default:
193 break;
194 case Stmt::LabelStmtClass:
195 return PGOHash::LabelStmt;
196 case Stmt::WhileStmtClass:
197 return PGOHash::WhileStmt;
198 case Stmt::DoStmtClass:
199 return PGOHash::DoStmt;
200 case Stmt::ForStmtClass:
201 return PGOHash::ForStmt;
202 case Stmt::CXXForRangeStmtClass:
203 return PGOHash::CXXForRangeStmt;
204 case Stmt::ObjCForCollectionStmtClass:
205 return PGOHash::ObjCForCollectionStmt;
206 case Stmt::SwitchStmtClass:
207 return PGOHash::SwitchStmt;
208 case Stmt::CaseStmtClass:
209 return PGOHash::CaseStmt;
210 case Stmt::DefaultStmtClass:
211 return PGOHash::DefaultStmt;
212 case Stmt::IfStmtClass:
213 return PGOHash::IfStmt;
214 case Stmt::CXXTryStmtClass:
215 return PGOHash::CXXTryStmt;
216 case Stmt::CXXCatchStmtClass:
217 return PGOHash::CXXCatchStmt;
218 case Stmt::ConditionalOperatorClass:
219 return PGOHash::ConditionalOperator;
220 case Stmt::BinaryConditionalOperatorClass:
221 return PGOHash::BinaryConditionalOperator;
222 case Stmt::BinaryOperatorClass: {
223 const BinaryOperator *BO = cast<BinaryOperator>(S);
224 if (BO->getOpcode() == BO_LAnd)
225 return PGOHash::BinaryOperatorLAnd;
226 if (BO->getOpcode() == BO_LOr)
227 return PGOHash::BinaryOperatorLOr;
228 break;
229 }
230 }
231 return PGOHash::None;
232 }
233 };
234
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
///
/// Statements that own a counter read it via PGO.getRegionCount(); all other
/// counts are derived by tracking CurrentCount along the traversal. Results
/// are stored in CountMap.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  /// Not set by the constructor; each Visit*Decl entry point assigns it via
  /// setCount() before visiting the body.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  /// One entry is pushed per enclosing loop (and switch); break/continue
  /// statements add CurrentCount to the innermost entry.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// If the previous statement requested it, store CurrentCount for \p S and
  /// clear the request.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  /// Default: record the count if requested, then visit all non-null children.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Nothing falls through a return; the next statement starts a new region
    // whose count should be recorded.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    // Control does not continue past a throw.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    // Control does not continue past a goto.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    // Credit the innermost loop/switch exit with the executions reaching here.
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    // Credit the innermost loop's condition/increment with these executions.
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Exit count: breaks, plus condition evaluations that did not enter the
    // body (CondCount - BodyCount).
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Exit count: breaks, plus condition evaluations that did not loop back
    // (CondCount - LoopCount).
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    // Visit the implicit setup statements before the loop proper.
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // Exit count: breaks, plus arrivals at the loop header (parent entry,
    // backedge, continues) that did not enter the body.
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    Visit(S->getCond());
    // The body is entered only through case/default labels, each of which
    // adds its own count in VisitSwitchCase.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    // The sub-statement records the combined (jump + fallthrough) count.
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    // NOTE(review): unsigned subtraction; wraps if a profile reports
    // ThenCount > ParentCount.
    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // Out count: ParentCount, minus executions that entered the RHS but did
    // not reach its end (RHSCount - CurrentCount).
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // Same out-count computation as VisitBinLAnd.
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
607 }
608
combine(HashType Type)609 void PGOHash::combine(HashType Type) {
610 // Check that we never combine 0 and only have six bits.
611 assert(Type && "Hash is invalid: unexpected type 0");
612 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
613
614 // Pass through MD5 if enough work has built up.
615 if (Count && Count % NumTypesPerWord == 0) {
616 using namespace llvm::support;
617 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
618 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
619 Working = 0;
620 }
621
622 // Accumulate the current type.
623 ++Count;
624 Working = Working << NumBitsPerType | Type;
625 }
626
/// Produce the final 64-bit hash of everything passed to combine().
///
/// If every combined type fit into a single Working word, that word is the
/// hash. Otherwise any leftover Working bits are fed to MD5, and the first
/// eight bytes of the digest, read as a little-endian integer, are returned.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  // NOTE(review): unlike combine(), this does not pass Working as eight
  // little-endian bytes. If llvm::MD5::update has no uint64_t overload, the
  // argument converts to a one-element uint8_t array, and only the low byte of
  // Working reaches the digest. Confirm against the MD5 API. Do not "fix" this
  // without a hash/profile-format version bump: changing the output
  // invalidates previously generated profiles (see the PGOHash class comment).
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return endian::read<uint64_t, little, unaligned>(Result);
}
645
checkGlobalDecl(GlobalDecl GD)646 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
647 // Make sure we only emit coverage mapping for one constructor/destructor.
648 // Clang emits several functions for the constructor and the destructor of
649 // a class. Every function is instrumented, but we only want to provide
650 // coverage for one of them. Because of that we only emit the coverage mapping
651 // for the base constructor/destructor.
652 if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
653 GD.getCtorType() != Ctor_Base) ||
654 (isa<CXXDestructorDecl>(GD.getDecl()) &&
655 GD.getDtorType() != Dtor_Base)) {
656 SkipCoverageMapping = true;
657 }
658 }
659
assignRegionCounters(const Decl * D,llvm::Function * Fn)660 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
661 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
662 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
663 if (!InstrumentRegions && !PGOReader)
664 return;
665 if (D->isImplicit())
666 return;
667 CGM.ClearUnusedCoverageMapping(D);
668 setFuncName(Fn);
669
670 mapRegionCounters(D);
671 if (CGM.getCodeGenOpts().CoverageMapping)
672 emitCounterRegionMapping(D);
673 if (PGOReader) {
674 SourceManager &SM = CGM.getContext().getSourceManager();
675 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
676 computeRegionCounts(D);
677 applyFunctionAttributes(PGOReader, Fn);
678 }
679 }
680
mapRegionCounters(const Decl * D)681 void CodeGenPGO::mapRegionCounters(const Decl *D) {
682 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
683 MapRegionCounters Walker(*RegionCounterMap);
684 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
685 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
686 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
687 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
688 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
689 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
690 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
691 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
692 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
693 NumRegionCounters = Walker.NextCounter;
694 FunctionHash = Walker.Hash.finalize();
695 }
696
emitCounterRegionMapping(const Decl * D)697 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
698 if (SkipCoverageMapping)
699 return;
700 // Don't map the functions inside the system headers
701 auto Loc = D->getBody()->getLocStart();
702 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
703 return;
704
705 std::string CoverageMapping;
706 llvm::raw_string_ostream OS(CoverageMapping);
707 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
708 CGM.getContext().getSourceManager(),
709 CGM.getLangOpts(), RegionCounterMap.get());
710 MappingGen.emitCounterMapping(D, OS);
711 OS.flush();
712
713 if (CoverageMapping.empty())
714 return;
715
716 CGM.getCoverageMapping()->addFunctionMappingRecord(
717 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
718 }
719
720 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)721 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
722 llvm::GlobalValue::LinkageTypes Linkage) {
723 if (SkipCoverageMapping)
724 return;
725 // Don't map the functions inside the system headers
726 auto Loc = D->getBody()->getLocStart();
727 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
728 return;
729
730 std::string CoverageMapping;
731 llvm::raw_string_ostream OS(CoverageMapping);
732 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
733 CGM.getContext().getSourceManager(),
734 CGM.getLangOpts());
735 MappingGen.emitEmptyMapping(D, OS);
736 OS.flush();
737
738 if (CoverageMapping.empty())
739 return;
740
741 setFuncName(Name, Linkage);
742 CGM.getCoverageMapping()->addFunctionMappingRecord(
743 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
744 }
745
computeRegionCounts(const Decl * D)746 void CodeGenPGO::computeRegionCounts(const Decl *D) {
747 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
748 ComputeRegionCounts Walker(*StmtCountMap, *this);
749 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
750 Walker.VisitFunctionDecl(FD);
751 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
752 Walker.VisitObjCMethodDecl(MD);
753 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
754 Walker.VisitBlockDecl(BD);
755 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
756 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
757 }
758
759 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)760 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
761 llvm::Function *Fn) {
762 if (!haveRegionCounts())
763 return;
764
765 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
766 uint64_t FunctionCount = getRegionCount(0);
767 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
768 // Turn on InlineHint attribute for hot functions.
769 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
770 Fn->addFnAttr(llvm::Attribute::InlineHint);
771 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
772 // Turn on Cold attribute for cold functions.
773 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
774 Fn->addFnAttr(llvm::Attribute::Cold);
775
776 Fn->setEntryCount(FunctionCount);
777 }
778
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S)779 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
780 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
781 return;
782 if (!Builder.GetInsertPoint())
783 return;
784
785 unsigned Counter = (*RegionCounterMap)[S];
786 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
787 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
788 {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
789 Builder.getInt64(FunctionHash),
790 Builder.getInt32(NumRegionCounters),
791 Builder.getInt32(Counter)});
792 }
793
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)794 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
795 bool IsInMainFile) {
796 CGM.getPGOStats().addVisited(IsInMainFile);
797 RegionCounts.clear();
798 if (std::error_code EC =
799 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
800 if (EC == llvm::instrprof_error::unknown_function)
801 CGM.getPGOStats().addMissing(IsInMainFile);
802 else if (EC == llvm::instrprof_error::hash_mismatch)
803 CGM.getPGOStats().addMismatched(IsInMainFile);
804 else if (EC == llvm::instrprof_error::malformed)
805 // TODO: Consider a more specific warning for this case.
806 CGM.getPGOStats().addMismatched(IsInMainFile);
807 RegionCounts.clear();
808 }
809 }
810
/// \brief Calculate what to divide by to scale weights.
///
/// Returns the smallest divisor that brings \p MaxWeight (and therefore every
/// weight no larger than it) strictly below UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Already small enough: no scaling needed.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
818
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is nonzero and was calculated by \a calculateWeightScale()
/// from a maximum weight no less than \c Weight. Otherwise the scaled value
/// may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The assertion above guarantees this narrowing is lossless.
  return static_cast<uint32_t>(Scaled);
}
834
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount)835 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
836 uint64_t FalseCount) {
837 // Check for empty weights.
838 if (!TrueCount && !FalseCount)
839 return nullptr;
840
841 // Calculate how to scale down to 32-bits.
842 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
843
844 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
845 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
846 scaleBranchWeight(FalseCount, Scale));
847 }
848
849 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights)850 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
851 // We need at least two elements to create meaningful weights.
852 if (Weights.size() < 2)
853 return nullptr;
854
855 // Check for empty weights.
856 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
857 if (MaxWeight == 0)
858 return nullptr;
859
860 // Calculate how to scale down to 32-bits.
861 uint64_t Scale = calculateWeightScale(MaxWeight);
862
863 SmallVector<uint32_t, 16> ScaledWeights;
864 ScaledWeights.reserve(Weights.size());
865 for (uint64_t W : Weights)
866 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
867
868 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
869 return MDHelper.createBranchWeights(ScaledWeights);
870 }
871
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount)872 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
873 uint64_t LoopCount) {
874 if (!PGO.haveRegionCounts())
875 return nullptr;
876 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
877 assert(CondCount.hasValue() && "missing expected loop condition count");
878 if (*CondCount == 0)
879 return nullptr;
880 return createProfileWeights(LoopCount,
881 std::max(*CondCount, LoopCount) - LoopCount);
882 }
883