1 | // |
2 | // Copyright 2018 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
// RewriteStructSamplers: Extract samplers from structs.
7 | // |
8 | |
9 | #include "compiler/translator/tree_ops/RewriteStructSamplers.h" |
10 | |
11 | #include "compiler/translator/ImmutableStringBuilder.h" |
12 | #include "compiler/translator/SymbolTable.h" |
13 | #include "compiler/translator/tree_util/IntermTraverse.h" |
14 | |
15 | namespace sh |
16 | { |
17 | namespace |
18 | { |
19 | // Helper method to get the sampler extracted struct type of a parameter. |
20 | TType *GetStructSamplerParameterType(TSymbolTable *symbolTable, const TVariable ¶m) |
21 | { |
22 | const TStructure *structure = param.getType().getStruct(); |
23 | const TSymbol *structSymbol = symbolTable->findUserDefined(structure->name()); |
24 | ASSERT(structSymbol && structSymbol->isStruct()); |
25 | const TStructure *structVar = static_cast<const TStructure *>(structSymbol); |
26 | TType *structType = new TType(structVar, false); |
27 | |
28 | if (param.getType().isArray()) |
29 | { |
30 | structType->makeArrays(*param.getType().getArraySizes()); |
31 | } |
32 | |
33 | ASSERT(!structType->isStructureContainingSamplers()); |
34 | |
35 | return structType; |
36 | } |
37 | |
38 | TIntermSymbol *ReplaceTypeOfSymbolNode(TIntermSymbol *symbolNode, TSymbolTable *symbolTable) |
39 | { |
40 | const TVariable &oldVariable = symbolNode->variable(); |
41 | |
42 | TType *newType = GetStructSamplerParameterType(symbolTable, oldVariable); |
43 | |
44 | TVariable *newVariable = |
45 | new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(), |
46 | oldVariable.extension(), newType); |
47 | return new TIntermSymbol(newVariable); |
48 | } |
49 | |
50 | TIntermTyped *ReplaceTypeOfTypedStructNode(TIntermTyped *argument, TSymbolTable *symbolTable) |
51 | { |
52 | TIntermSymbol *asSymbol = argument->getAsSymbolNode(); |
53 | if (asSymbol) |
54 | { |
55 | ASSERT(asSymbol->getType().getStruct()); |
56 | return ReplaceTypeOfSymbolNode(asSymbol, symbolTable); |
57 | } |
58 | |
59 | TIntermTyped *replacement = argument->deepCopy(); |
60 | TIntermBinary *binary = replacement->getAsBinaryNode(); |
61 | ASSERT(binary); |
62 | |
63 | while (binary) |
64 | { |
65 | ASSERT(binary->getOp() == EOpIndexDirectStruct || binary->getOp() == EOpIndexDirect); |
66 | |
67 | asSymbol = binary->getLeft()->getAsSymbolNode(); |
68 | |
69 | if (asSymbol) |
70 | { |
71 | ASSERT(asSymbol->getType().getStruct()); |
72 | TIntermSymbol *newSymbol = ReplaceTypeOfSymbolNode(asSymbol, symbolTable); |
73 | binary->replaceChildNode(binary->getLeft(), newSymbol); |
74 | return replacement; |
75 | } |
76 | |
77 | binary = binary->getLeft()->getAsBinaryNode(); |
78 | } |
79 | |
80 | UNREACHABLE(); |
81 | return nullptr; |
82 | } |
83 | |
84 | // Maximum string size of a hex unsigned int. |
85 | constexpr size_t kHexSize = ImmutableStringBuilder::GetHexCharCount<unsigned int>(); |
86 | |
87 | class Traverser final : public TIntermTraverser |
88 | { |
89 | public: |
90 | explicit Traverser(TSymbolTable *symbolTable) |
91 | : TIntermTraverser(true, false, true, symbolTable), mRemovedUniformsCount(0) |
92 | { |
93 | mSymbolTable->push(); |
94 | } |
95 | |
96 | ~Traverser() override { mSymbolTable->pop(); } |
97 | |
98 | int removedUniformsCount() const { return mRemovedUniformsCount; } |
99 | |
100 | // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each |
101 | // stripped struct sampler. |
102 | bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override |
103 | { |
104 | if (visit != PreVisit) |
105 | return true; |
106 | |
107 | if (!mInGlobalScope) |
108 | { |
109 | return true; |
110 | } |
111 | |
112 | const TIntermSequence &sequence = *(decl->getSequence()); |
113 | TIntermTyped *declarator = sequence.front()->getAsTyped(); |
114 | const TType &type = declarator->getType(); |
115 | |
116 | if (type.isStructureContainingSamplers()) |
117 | { |
118 | TIntermSequence *newSequence = new TIntermSequence; |
119 | |
120 | if (type.isStructSpecifier()) |
121 | { |
122 | stripStructSpecifierSamplers(type.getStruct(), newSequence); |
123 | } |
124 | else |
125 | { |
126 | TIntermSymbol *asSymbol = declarator->getAsSymbolNode(); |
127 | ASSERT(asSymbol); |
128 | const TVariable &variable = asSymbol->variable(); |
129 | ASSERT(variable.symbolType() != SymbolType::Empty); |
130 | extractStructSamplerUniforms(decl, variable, type.getStruct(), newSequence); |
131 | } |
132 | |
133 | mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence); |
134 | } |
135 | |
136 | return true; |
137 | } |
138 | |
139 | // Each struct sampler reference is replaced with a reference to the new extracted sampler. |
140 | bool visitBinary(Visit visit, TIntermBinary *node) override |
141 | { |
142 | if (visit != PreVisit) |
143 | return true; |
144 | |
145 | if (node->getOp() == EOpIndexDirectStruct && node->getType().isSampler()) |
146 | { |
147 | ImmutableString newName = GetStructSamplerNameFromTypedNode(node); |
148 | const TVariable *samplerReplacement = |
149 | static_cast<const TVariable *>(mSymbolTable->findUserDefined(newName)); |
150 | ASSERT(samplerReplacement); |
151 | |
152 | TIntermSymbol *replacement = new TIntermSymbol(samplerReplacement); |
153 | |
154 | queueReplacement(replacement, OriginalNode::IS_DROPPED); |
155 | return true; |
156 | } |
157 | |
158 | return true; |
159 | } |
160 | |
161 | // In we are passing references to structs containing samplers we must new additional |
162 | // arguments. For each extracted struct sampler a new argument is added. This chains to nested |
163 | // structs. |
164 | void visitFunctionPrototype(TIntermFunctionPrototype *node) override |
165 | { |
166 | const TFunction *function = node->getFunction(); |
167 | |
168 | if (!function->hasSamplerInStructParams()) |
169 | { |
170 | return; |
171 | } |
172 | |
173 | const TSymbol *foundFunction = mSymbolTable->findUserDefined(function->name()); |
174 | if (foundFunction) |
175 | { |
176 | ASSERT(foundFunction->isFunction()); |
177 | function = static_cast<const TFunction *>(foundFunction); |
178 | } |
179 | else |
180 | { |
181 | TFunction *newFunction = createStructSamplerFunction(function); |
182 | mSymbolTable->declareUserDefinedFunction(newFunction, true); |
183 | function = newFunction; |
184 | } |
185 | |
186 | ASSERT(!function->hasSamplerInStructParams()); |
187 | TIntermFunctionPrototype *newProto = new TIntermFunctionPrototype(function); |
188 | queueReplacement(newProto, OriginalNode::IS_DROPPED); |
189 | } |
190 | |
191 | // We insert a new scope for each function definition so we can track the new parameters. |
192 | bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override |
193 | { |
194 | if (visit == PreVisit) |
195 | { |
196 | mSymbolTable->push(); |
197 | } |
198 | else |
199 | { |
200 | ASSERT(visit == PostVisit); |
201 | mSymbolTable->pop(); |
202 | } |
203 | return true; |
204 | } |
205 | |
206 | // For function call nodes we pass references to the extracted struct samplers in that scope. |
207 | bool visitAggregate(Visit visit, TIntermAggregate *node) override |
208 | { |
209 | if (visit != PreVisit) |
210 | return true; |
211 | |
212 | if (!node->isFunctionCall()) |
213 | return true; |
214 | |
215 | const TFunction *function = node->getFunction(); |
216 | if (!function->hasSamplerInStructParams()) |
217 | return true; |
218 | |
219 | ASSERT(node->getOp() == EOpCallFunctionInAST); |
220 | TFunction *newFunction = mSymbolTable->findUserDefinedFunction(function->name()); |
221 | TIntermSequence *newArguments = getStructSamplerArguments(function, node->getSequence()); |
222 | |
223 | TIntermAggregate *newCall = |
224 | TIntermAggregate::CreateFunctionCall(*newFunction, newArguments); |
225 | queueReplacement(newCall, OriginalNode::IS_DROPPED); |
226 | return true; |
227 | } |
228 | |
229 | private: |
230 | // This returns the name of a struct sampler reference. References are always TIntermBinary. |
231 | static ImmutableString GetStructSamplerNameFromTypedNode(TIntermTyped *node) |
232 | { |
233 | std::string stringBuilder; |
234 | |
235 | TIntermTyped *currentNode = node; |
236 | while (currentNode->getAsBinaryNode()) |
237 | { |
238 | TIntermBinary *asBinary = currentNode->getAsBinaryNode(); |
239 | |
240 | switch (asBinary->getOp()) |
241 | { |
242 | case EOpIndexDirect: |
243 | { |
244 | const int index = asBinary->getRight()->getAsConstantUnion()->getIConst(0); |
245 | const std::string strInt = Str(index); |
246 | stringBuilder.insert(0, strInt); |
247 | stringBuilder.insert(0, "_" ); |
248 | break; |
249 | } |
250 | case EOpIndexDirectStruct: |
251 | { |
252 | stringBuilder.insert(0, asBinary->getIndexStructFieldName().data()); |
253 | stringBuilder.insert(0, "_" ); |
254 | break; |
255 | } |
256 | |
257 | default: |
258 | UNREACHABLE(); |
259 | break; |
260 | } |
261 | |
262 | currentNode = asBinary->getLeft(); |
263 | } |
264 | |
265 | const ImmutableString &variableName = currentNode->getAsSymbolNode()->variable().name(); |
266 | stringBuilder.insert(0, variableName.data()); |
267 | |
268 | return stringBuilder; |
269 | } |
270 | |
271 | // Removes all the struct samplers from a struct specifier. |
272 | void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence) |
273 | { |
274 | TFieldList *newFieldList = new TFieldList; |
275 | ASSERT(structure->containsSamplers()); |
276 | |
277 | for (const TField *field : structure->fields()) |
278 | { |
279 | const TType &fieldType = *field->type(); |
280 | if (!fieldType.isSampler() && !isRemovedStructType(fieldType)) |
281 | { |
282 | TType *newType = nullptr; |
283 | |
284 | if (fieldType.isStructureContainingSamplers()) |
285 | { |
286 | const TSymbol *structSymbol = |
287 | mSymbolTable->findUserDefined(fieldType.getStruct()->name()); |
288 | ASSERT(structSymbol && structSymbol->isStruct()); |
289 | const TStructure *fieldStruct = static_cast<const TStructure *>(structSymbol); |
290 | newType = new TType(fieldStruct, true); |
291 | if (fieldType.isArray()) |
292 | { |
293 | newType->makeArrays(*fieldType.getArraySizes()); |
294 | } |
295 | } |
296 | else |
297 | { |
298 | newType = new TType(fieldType); |
299 | } |
300 | |
301 | TField *newField = |
302 | new TField(newType, field->name(), field->line(), field->symbolType()); |
303 | newFieldList->push_back(newField); |
304 | } |
305 | } |
306 | |
307 | // Prune empty structs. |
308 | if (newFieldList->empty()) |
309 | { |
310 | mRemovedStructs.insert(structure->name()); |
311 | return; |
312 | } |
313 | |
314 | TStructure *newStruct = |
315 | new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType()); |
316 | TType *newStructType = new TType(newStruct, true); |
317 | TVariable *newStructVar = |
318 | new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty); |
319 | TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar); |
320 | |
321 | TIntermDeclaration *structDecl = new TIntermDeclaration; |
322 | structDecl->appendDeclarator(newStructRef); |
323 | |
324 | newSequence->push_back(structDecl); |
325 | |
326 | mSymbolTable->declare(newStruct); |
327 | } |
328 | |
329 | // Returns true if the type is a struct that was removed because we extracted all the members. |
330 | bool isRemovedStructType(const TType &type) const |
331 | { |
332 | const TStructure *structure = type.getStruct(); |
333 | return (structure && (mRemovedStructs.count(structure->name()) > 0)); |
334 | } |
335 | |
336 | // Removes samplers from struct uniforms. For each sampler removed also adds a new globally |
337 | // defined sampler uniform. |
338 | void (TIntermDeclaration *oldDeclaration, |
339 | const TVariable &variable, |
340 | const TStructure *structure, |
341 | TIntermSequence *newSequence) |
342 | { |
343 | ASSERT(structure->containsSamplers()); |
344 | |
345 | size_t nonSamplerCount = 0; |
346 | |
347 | for (const TField *field : structure->fields()) |
348 | { |
349 | nonSamplerCount += |
350 | extractFieldSamplers(variable.name(), field, variable.getType(), newSequence); |
351 | } |
352 | |
353 | if (nonSamplerCount > 0) |
354 | { |
355 | // Keep the old declaration around if it has other members. |
356 | newSequence->push_back(oldDeclaration); |
357 | } |
358 | else |
359 | { |
360 | mRemovedUniformsCount++; |
361 | } |
362 | } |
363 | |
364 | // Extracts samplers from a field of a struct. Works with nested structs and arrays. |
365 | size_t (const ImmutableString &prefix, |
366 | const TField *field, |
367 | const TType &containingType, |
368 | TIntermSequence *newSequence) |
369 | { |
370 | if (containingType.isArray()) |
371 | { |
372 | size_t nonSamplerCount = 0; |
373 | |
374 | // Name the samplers internally as varName_<index>_fieldName |
375 | const TVector<unsigned int> &arraySizes = *containingType.getArraySizes(); |
376 | for (unsigned int arrayElement = 0; arrayElement < arraySizes[0]; ++arrayElement) |
377 | { |
378 | ImmutableStringBuilder stringBuilder(prefix.length() + kHexSize + 1); |
379 | stringBuilder << prefix << "_" ; |
380 | stringBuilder.appendHex(arrayElement); |
381 | nonSamplerCount = extractFieldSamplersImpl(stringBuilder, field, newSequence); |
382 | } |
383 | |
384 | return nonSamplerCount; |
385 | } |
386 | |
387 | return extractFieldSamplersImpl(prefix, field, newSequence); |
388 | } |
389 | |
390 | // Extracts samplers from a field of a struct. Works with nested structs and arrays. |
391 | size_t (const ImmutableString &prefix, |
392 | const TField *field, |
393 | TIntermSequence *newSequence) |
394 | { |
395 | size_t nonSamplerCount = 0; |
396 | |
397 | const TType &fieldType = *field->type(); |
398 | if (fieldType.isSampler() || fieldType.isStructureContainingSamplers()) |
399 | { |
400 | ImmutableStringBuilder stringBuilder(prefix.length() + field->name().length() + 1); |
401 | stringBuilder << prefix << "_" << field->name(); |
402 | ImmutableString newPrefix(stringBuilder); |
403 | |
404 | if (fieldType.isSampler()) |
405 | { |
406 | extractSampler(newPrefix, fieldType, newSequence); |
407 | } |
408 | else |
409 | { |
410 | const TStructure *structure = fieldType.getStruct(); |
411 | for (const TField *nestedField : structure->fields()) |
412 | { |
413 | nonSamplerCount += |
414 | extractFieldSamplers(newPrefix, nestedField, fieldType, newSequence); |
415 | } |
416 | } |
417 | } |
418 | else |
419 | { |
420 | nonSamplerCount++; |
421 | } |
422 | |
423 | return nonSamplerCount; |
424 | } |
425 | |
426 | // Extracts a sampler from a struct. Declares the new extracted sampler. |
427 | void (const ImmutableString &newName, |
428 | const TType &fieldType, |
429 | TIntermSequence *newSequence) const |
430 | { |
431 | TType *newType = new TType(fieldType); |
432 | newType->setQualifier(EvqUniform); |
433 | TVariable *newVariable = |
434 | new TVariable(mSymbolTable, newName, newType, SymbolType::AngleInternal); |
435 | TIntermSymbol *newRef = new TIntermSymbol(newVariable); |
436 | |
437 | TIntermDeclaration *samplerDecl = new TIntermDeclaration; |
438 | samplerDecl->appendDeclarator(newRef); |
439 | |
440 | newSequence->push_back(samplerDecl); |
441 | |
442 | mSymbolTable->declareInternal(newVariable); |
443 | } |
444 | |
445 | // Returns the chained name of a sampler uniform field. |
446 | static ImmutableString GetFieldName(const ImmutableString ¶mName, |
447 | const TField *field, |
448 | unsigned arrayIndex) |
449 | { |
450 | ImmutableStringBuilder nameBuilder(paramName.length() + kHexSize + 2 + |
451 | field->name().length()); |
452 | nameBuilder << paramName << "_" ; |
453 | |
454 | if (arrayIndex < std::numeric_limits<unsigned>::max()) |
455 | { |
456 | nameBuilder.appendHex(arrayIndex); |
457 | nameBuilder << "_" ; |
458 | } |
459 | nameBuilder << field->name(); |
460 | |
461 | return nameBuilder; |
462 | } |
463 | |
464 | // A pattern that visits every parameter of a function call. Uses different handlers for struct |
465 | // parameters, struct sampler parameters, and non-struct parameters. |
466 | class StructSamplerFunctionVisitor : angle::NonCopyable |
467 | { |
468 | public: |
469 | StructSamplerFunctionVisitor() = default; |
470 | virtual ~StructSamplerFunctionVisitor() = default; |
471 | |
472 | virtual void traverse(const TFunction *function) |
473 | { |
474 | size_t paramCount = function->getParamCount(); |
475 | |
476 | for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex) |
477 | { |
478 | const TVariable *param = function->getParam(paramIndex); |
479 | const TType ¶mType = param->getType(); |
480 | |
481 | if (paramType.isStructureContainingSamplers()) |
482 | { |
483 | const ImmutableString &baseName = getNameFromIndex(function, paramIndex); |
484 | if (traverseStructContainingSamplers(baseName, paramType)) |
485 | { |
486 | visitStructParam(function, paramIndex); |
487 | } |
488 | } |
489 | else |
490 | { |
491 | visitNonStructParam(function, paramIndex); |
492 | } |
493 | } |
494 | } |
495 | |
496 | virtual ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) = 0; |
497 | virtual void visitSamplerInStructParam(const ImmutableString &name, |
498 | const TField *field) = 0; |
499 | virtual void visitStructParam(const TFunction *function, size_t paramIndex) = 0; |
500 | virtual void visitNonStructParam(const TFunction *function, size_t paramIndex) = 0; |
501 | |
502 | private: |
503 | bool traverseStructContainingSamplers(const ImmutableString &baseName, |
504 | const TType &structType) |
505 | { |
506 | bool hasNonSamplerFields = false; |
507 | const TStructure *structure = structType.getStruct(); |
508 | for (const TField *field : structure->fields()) |
509 | { |
510 | if (field->type()->isStructureContainingSamplers() || field->type()->isSampler()) |
511 | { |
512 | if (traverseSamplerInStruct(baseName, structType, field)) |
513 | { |
514 | hasNonSamplerFields = true; |
515 | } |
516 | } |
517 | else |
518 | { |
519 | hasNonSamplerFields = true; |
520 | } |
521 | } |
522 | return hasNonSamplerFields; |
523 | } |
524 | |
525 | bool traverseSamplerInStruct(const ImmutableString &baseName, |
526 | const TType &baseType, |
527 | const TField *field) |
528 | { |
529 | bool hasNonSamplerParams = false; |
530 | |
531 | if (baseType.isArray()) |
532 | { |
533 | const TVector<unsigned int> &arraySizes = *baseType.getArraySizes(); |
534 | ASSERT(arraySizes.size() == 1); |
535 | |
536 | for (unsigned int arrayIndex = 0; arrayIndex < arraySizes[0]; ++arrayIndex) |
537 | { |
538 | ImmutableString name = GetFieldName(baseName, field, arrayIndex); |
539 | |
540 | if (field->type()->isStructureContainingSamplers()) |
541 | { |
542 | if (traverseStructContainingSamplers(name, *field->type())) |
543 | { |
544 | hasNonSamplerParams = true; |
545 | } |
546 | } |
547 | else |
548 | { |
549 | ASSERT(field->type()->isSampler()); |
550 | visitSamplerInStructParam(name, field); |
551 | } |
552 | } |
553 | } |
554 | else if (field->type()->isStructureContainingSamplers()) |
555 | { |
556 | ImmutableString name = |
557 | GetFieldName(baseName, field, std::numeric_limits<unsigned>::max()); |
558 | hasNonSamplerParams = traverseStructContainingSamplers(name, *field->type()); |
559 | } |
560 | else |
561 | { |
562 | ASSERT(field->type()->isSampler()); |
563 | ImmutableString name = |
564 | GetFieldName(baseName, field, std::numeric_limits<unsigned>::max()); |
565 | visitSamplerInStructParam(name, field); |
566 | } |
567 | |
568 | return hasNonSamplerParams; |
569 | } |
570 | }; |
571 | |
572 | // A visitor that replaces functions with struct sampler references. The struct sampler |
573 | // references are expanded to include new fields for the structs. |
574 | class CreateStructSamplerFunctionVisitor final : public StructSamplerFunctionVisitor |
575 | { |
576 | public: |
577 | CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable) |
578 | : mSymbolTable(symbolTable), mNewFunction(nullptr) |
579 | {} |
580 | |
581 | ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override |
582 | { |
583 | const TVariable *param = function->getParam(paramIndex); |
584 | return param->name(); |
585 | } |
586 | |
587 | void traverse(const TFunction *function) override |
588 | { |
589 | mNewFunction = |
590 | new TFunction(mSymbolTable, function->name(), function->symbolType(), |
591 | &function->getReturnType(), function->isKnownToNotHaveSideEffects()); |
592 | |
593 | StructSamplerFunctionVisitor::traverse(function); |
594 | } |
595 | |
596 | void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override |
597 | { |
598 | TVariable *fieldSampler = |
599 | new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal); |
600 | mNewFunction->addParameter(fieldSampler); |
601 | mSymbolTable->declareInternal(fieldSampler); |
602 | } |
603 | |
604 | void visitStructParam(const TFunction *function, size_t paramIndex) override |
605 | { |
606 | const TVariable *param = function->getParam(paramIndex); |
607 | TType *structType = GetStructSamplerParameterType(mSymbolTable, *param); |
608 | TVariable *newParam = |
609 | new TVariable(mSymbolTable, param->name(), structType, param->symbolType()); |
610 | mNewFunction->addParameter(newParam); |
611 | } |
612 | |
613 | void visitNonStructParam(const TFunction *function, size_t paramIndex) override |
614 | { |
615 | const TVariable *param = function->getParam(paramIndex); |
616 | mNewFunction->addParameter(param); |
617 | } |
618 | |
619 | TFunction *getNewFunction() const { return mNewFunction; } |
620 | |
621 | private: |
622 | TSymbolTable *mSymbolTable; |
623 | TFunction *mNewFunction; |
624 | }; |
625 | |
626 | TFunction *createStructSamplerFunction(const TFunction *function) const |
627 | { |
628 | CreateStructSamplerFunctionVisitor visitor(mSymbolTable); |
629 | visitor.traverse(function); |
630 | return visitor.getNewFunction(); |
631 | } |
632 | |
633 | // A visitor that replaces function calls with expanded struct sampler parameters. |
634 | class GetSamplerArgumentsVisitor final : public StructSamplerFunctionVisitor |
635 | { |
636 | public: |
637 | GetSamplerArgumentsVisitor(TSymbolTable *symbolTable, const TIntermSequence *arguments) |
638 | : mSymbolTable(symbolTable), mArguments(arguments), mNewArguments(new TIntermSequence) |
639 | {} |
640 | |
641 | ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override |
642 | { |
643 | TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped(); |
644 | return GetStructSamplerNameFromTypedNode(argument); |
645 | } |
646 | |
647 | void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override |
648 | { |
649 | TVariable *argSampler = |
650 | new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal); |
651 | TIntermSymbol *argSymbol = new TIntermSymbol(argSampler); |
652 | mNewArguments->push_back(argSymbol); |
653 | } |
654 | |
655 | void visitStructParam(const TFunction *function, size_t paramIndex) override |
656 | { |
657 | // The tree structure of the parameter is modified to point to the new type. This leaves |
658 | // the tree in a consistent state. |
659 | TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped(); |
660 | TIntermTyped *replacement = ReplaceTypeOfTypedStructNode(argument, mSymbolTable); |
661 | mNewArguments->push_back(replacement); |
662 | } |
663 | |
664 | void visitNonStructParam(const TFunction *function, size_t paramIndex) override |
665 | { |
666 | TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped(); |
667 | mNewArguments->push_back(argument); |
668 | } |
669 | |
670 | TIntermSequence *getNewArguments() const { return mNewArguments; } |
671 | |
672 | private: |
673 | TSymbolTable *mSymbolTable; |
674 | const TIntermSequence *mArguments; |
675 | TIntermSequence *mNewArguments; |
676 | }; |
677 | |
678 | TIntermSequence *getStructSamplerArguments(const TFunction *function, |
679 | const TIntermSequence *arguments) const |
680 | { |
681 | GetSamplerArgumentsVisitor visitor(mSymbolTable, arguments); |
682 | visitor.traverse(function); |
683 | return visitor.getNewArguments(); |
684 | } |
685 | |
686 | int mRemovedUniformsCount; |
687 | std::set<ImmutableString> mRemovedStructs; |
688 | }; |
689 | } // anonymous namespace |
690 | |
691 | int RewriteStructSamplers(TIntermBlock *root, TSymbolTable *symbolTable) |
692 | { |
693 | Traverser rewriteStructSamplers(symbolTable); |
694 | root->traverse(&rewriteStructSamplers); |
695 | rewriteStructSamplers.updateTree(); |
696 | |
697 | return rewriteStructSamplers.removedUniformsCount(); |
698 | } |
699 | } // namespace sh |
700 | |