1/*
2 * Copyright (C) 2019 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23 * THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "config.h"
27#include "WHLSLChecker.h"
28
29#if ENABLE(WEBGPU)
30
31#include "WHLSLArrayReferenceType.h"
32#include "WHLSLArrayType.h"
33#include "WHLSLAssignmentExpression.h"
34#include "WHLSLCallExpression.h"
35#include "WHLSLCommaExpression.h"
36#include "WHLSLDereferenceExpression.h"
37#include "WHLSLDoWhileLoop.h"
38#include "WHLSLDotExpression.h"
39#include "WHLSLForLoop.h"
40#include "WHLSLGatherEntryPointItems.h"
41#include "WHLSLIfStatement.h"
42#include "WHLSLIndexExpression.h"
43#include "WHLSLInferTypes.h"
44#include "WHLSLLogicalExpression.h"
45#include "WHLSLLogicalNotExpression.h"
46#include "WHLSLMakeArrayReferenceExpression.h"
47#include "WHLSLMakePointerExpression.h"
48#include "WHLSLPointerType.h"
49#include "WHLSLProgram.h"
50#include "WHLSLReadModifyWriteExpression.h"
51#include "WHLSLResolvableType.h"
52#include "WHLSLResolveOverloadImpl.h"
53#include "WHLSLResolvingType.h"
54#include "WHLSLReturn.h"
55#include "WHLSLSwitchStatement.h"
56#include "WHLSLTernaryExpression.h"
57#include "WHLSLVisitor.h"
58#include "WHLSLWhileLoop.h"
59#include <wtf/HashMap.h>
60#include <wtf/HashSet.h>
61#include <wtf/Ref.h>
62#include <wtf/Vector.h>
63#include <wtf/text/WTFString.h>
64
65namespace WebCore {
66
67namespace WHLSL {
68
69class PODChecker : public Visitor {
70public:
71 PODChecker() = default;
72
73 virtual ~PODChecker() = default;
74
75 void visit(AST::EnumerationDefinition& enumerationDefinition) override
76 {
77 Visitor::visit(enumerationDefinition);
78 }
79
80 void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
81 {
82 if (!nativeTypeDeclaration.isNumber()
83 && !nativeTypeDeclaration.isVector()
84 && !nativeTypeDeclaration.isMatrix())
85 setError();
86 }
87
88 void visit(AST::StructureDefinition& structureDefinition) override
89 {
90 Visitor::visit(structureDefinition);
91 }
92
93 void visit(AST::TypeDefinition& typeDefinition) override
94 {
95 Visitor::visit(typeDefinition);
96 }
97
98 void visit(AST::ArrayType& arrayType) override
99 {
100 Visitor::visit(arrayType);
101 }
102
103 void visit(AST::PointerType&) override
104 {
105 setError();
106 }
107
108 void visit(AST::ArrayReferenceType&) override
109 {
110 setError();
111 }
112
113 void visit(AST::TypeReference& typeReference) override
114 {
115 ASSERT(typeReference.resolvedType());
116 checkErrorAndVisit(*typeReference.resolvedType());
117 }
118};
119
120static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(AST::CallExpression& callExpression, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
121{
122 const bool isOperator = true;
123 auto returnType = makeUniqueRef<AST::PointerType>(Lexer::Token(callExpression.origin()), firstArgument.addressSpace(), firstArgument.elementType().clone());
124 AST::VariableDeclarations parameters;
125 parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { firstArgument.clone() }, String(), WTF::nullopt, WTF::nullopt));
126 parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.uintType()) }, String(), WTF::nullopt, WTF::nullopt));
127 return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
128}
129
130static AST::NativeFunctionDeclaration resolveWithOperatorLength(AST::CallExpression& callExpression, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics)
131{
132 const bool isOperator = true;
133 auto returnType = AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.uintType());
134 AST::VariableDeclarations parameters;
135 parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { firstArgument.clone() }, String(), WTF::nullopt, WTF::nullopt));
136 return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
137}
138
139static AST::NativeFunctionDeclaration resolveWithReferenceComparator(AST::CallExpression& callExpression, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
140{
141 const bool isOperator = true;
142 auto returnType = AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.boolType());
143 auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
144 return unnamedType->clone();
145 }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
146 return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
147 return unnamedType->clone();
148 }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
149 // We encountered "null == null".
150 // FIXME: This can probably be generalized, using the "preferred type" infrastructure used by generic literals
151 ASSERT_NOT_REACHED();
152 return AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.intType());
153 }));
154 }));
155 AST::VariableDeclarations parameters;
156 parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { argumentType->clone() }, String(), WTF::nullopt, WTF::nullopt));
157 parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { WTFMove(argumentType) }, String(), WTF::nullopt, WTF::nullopt));
158 return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
159}
160
// How well an argument fits a synthesized reference comparison (`operator==`):
// Yes for a concrete reference type, Maybe for an unresolved null literal, No otherwise.
enum class Acceptability { Yes, Maybe, No };
166
167static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(AST::CallExpression& callExpression, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
168{
169 if (callExpression.name() == "operator&[]" && types.size() == 2) {
170 auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
171 if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
172 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
173 return nullptr;
174 }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
175 return nullptr;
176 }));
177 bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
178 return matches(unnamedType, intrinsics.uintType());
179 }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
180 return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
181 }));
182 if (firstArgumentArrayRef && secondArgumentIsUint)
183 return resolveWithOperatorAnderIndexer(callExpression, *firstArgumentArrayRef, intrinsics);
184 } else if (callExpression.name() == "operator.length" && types.size() == 1) {
185 auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
186 if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
187 return &unnamedType;
188 return nullptr;
189 }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
190 return nullptr;
191 }));
192 if (firstArgumentReference)
193 return resolveWithOperatorLength(callExpression, *firstArgumentReference, intrinsics);
194 } else if (callExpression.name() == "operator==" && types.size() == 2) {
195 auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
196 return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability {
197 return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) ? Acceptability::Yes : Acceptability::No;
198 }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
199 return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
200 }));
201 };
202 auto leftAcceptability = acceptability(types[0].get());
203 auto rightAcceptability = acceptability(types[1].get());
204 bool success = false;
205 if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
206 auto& unnamedType1 = types[0].get().getUnnamedType();
207 auto& unnamedType2 = types[1].get().getUnnamedType();
208 success = matches(unnamedType1, unnamedType2);
209 } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
210 || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
211 success = true;
212 if (success)
213 return resolveWithReferenceComparator(callExpression, types[0].get(), types[1].get(), intrinsics);
214 }
215 return WTF::nullopt;
216}
217
218static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
219{
220 {
221 auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool {
222 for (size_t i = 0; i < items.size(); ++i) {
223 for (size_t j = i + 1; j < items.size(); ++j) {
224 if (items[i].semantic == items[j].semantic)
225 return false;
226 }
227 }
228 return true;
229 };
230 if (!checkDuplicateSemantics(inputItems))
231 return false;
232 if (!checkDuplicateSemantics(outputItems))
233 return false;
234 }
235
236 {
237 auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool {
238 for (auto& item : items) {
239 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
240 return semantic.isAcceptableType(*item.unnamedType, intrinsics);
241 }), *item.semantic);
242 if (!acceptable)
243 return false;
244 }
245 return true;
246 };
247 if (!checkSemanticTypes(inputItems))
248 return false;
249 if (!checkSemanticTypes(outputItems))
250 return false;
251 }
252
253 {
254 auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
255 for (auto& item : items) {
256 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
257 return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
258 }), *item.semantic);
259 if (!acceptable)
260 return false;
261 }
262 return true;
263 };
264 if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input))
265 return false;
266 if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
267 return false;
268 }
269
270 {
271 auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool {
272 for (auto& item : items) {
273 PODChecker podChecker;
274 if (is<AST::PointerType>(item.unnamedType))
275 podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
276 else if (is<AST::ArrayReferenceType>(item.unnamedType))
277 podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
278 else if (is<AST::ArrayType>(item.unnamedType))
279 podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
280 else
281 continue;
282 if (podChecker.error())
283 return false;
284 }
285 return true;
286 };
287 if (!checkPODData(inputItems))
288 return false;
289 if (!checkPODData(outputItems))
290 return false;
291 }
292
293 return true;
294}
295
296static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, const Intrinsics& intrinsics, NameContext& nameContext)
297{
298 enum class CheckKind {
299 Index,
300 Dot
301 };
302
303 auto checkGetter = [&](CheckKind kind) -> bool {
304 size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
305 if (functionDefinition.parameters().size() != numExpectedParameters)
306 return false;
307 auto& firstParameterUnifyNode = (*functionDefinition.parameters()[0].type())->unifyNode();
308 if (is<AST::UnnamedType>(firstParameterUnifyNode)) {
309 auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode);
310 if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
311 return false;
312 }
313 if (kind == CheckKind::Index) {
314 auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1].type())->unifyNode();
315 if (!is<AST::NamedType>(secondParameterUnifyNode))
316 return false;
317 auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
318 if (!is<AST::NativeTypeDeclaration>(namedType))
319 return false;
320 auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
321 if (!nativeTypeDeclaration.isInt())
322 return false;
323 }
324 return true;
325 };
326
327 auto checkSetter = [&](CheckKind kind) -> bool {
328 size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2;
329 if (functionDefinition.parameters().size() != numExpectedParameters)
330 return false;
331 auto& firstArgumentUnifyNode = (*functionDefinition.parameters()[0].type())->unifyNode();
332 if (is<AST::UnnamedType>(firstArgumentUnifyNode)) {
333 auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode);
334 if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
335 return false;
336 }
337 if (kind == CheckKind::Index) {
338 auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1].type())->unifyNode();
339 if (!is<AST::NamedType>(secondParameterUnifyNode))
340 return false;
341 auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
342 if (!is<AST::NativeTypeDeclaration>(namedType))
343 return false;
344 auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
345 if (!nativeTypeDeclaration.isInt())
346 return false;
347 }
348 if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0].type()))
349 return false;
350 auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1].type();
351 auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1);
352 auto* getterFuncs = nameContext.getFunctions(getterName);
353 if (!getterFuncs)
354 return false;
355 Vector<ResolvingType> argumentTypes;
356 Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences;
357 for (size_t i = 0; i < numExpectedParameters - 1; ++i)
358 argumentTypes.append((*functionDefinition.parameters()[0].type())->clone());
359 for (auto& argumentType : argumentTypes)
360 argumentTypeReferences.append(argumentType);
361 auto* overload = resolveFunctionOverloadImpl(*getterFuncs, argumentTypeReferences, nullptr);
362 if (!overload)
363 return false;
364 auto& resultType = overload->type();
365 return matches(resultType, valueType);
366 };
367
368 auto checkAnder = [&](CheckKind kind) -> bool {
369 size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
370 if (functionDefinition.parameters().size() != numExpectedParameters)
371 return false;
372 {
373 auto& unifyNode = functionDefinition.type().unifyNode();
374 if (!is<AST::UnnamedType>(unifyNode))
375 return false;
376 auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
377 if (!is<AST::PointerType>(unnamedType))
378 return false;
379 }
380 {
381 auto& unifyNode = (*functionDefinition.parameters()[0].type())->unifyNode();
382 if (!is<AST::UnnamedType>(unifyNode))
383 return false;
384 auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
385 return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType);
386 }
387 };
388
389 if (!functionDefinition.isOperator())
390 return true;
391 if (functionDefinition.isCast())
392 return true;
393 if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--") {
394 return functionDefinition.parameters().size() == 1
395 && matches(*functionDefinition.parameters()[0].type(), functionDefinition.type());
396 }
397 if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-")
398 return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2;
399 if (functionDefinition.name() == "operator*"
400 || functionDefinition.name() == "operator/"
401 || functionDefinition.name() == "operator%"
402 || functionDefinition.name() == "operator&"
403 || functionDefinition.name() == "operator|"
404 || functionDefinition.name() == "operator^"
405 || functionDefinition.name() == "operator<<"
406 || functionDefinition.name() == "opreator>>")
407 return functionDefinition.parameters().size() == 2;
408 if (functionDefinition.name() == "operator~")
409 return functionDefinition.parameters().size() == 1;
410 if (functionDefinition.name() == "operator=="
411 || functionDefinition.name() == "operator<"
412 || functionDefinition.name() == "operator<="
413 || functionDefinition.name() == "operator>"
414 || functionDefinition.name() == "operator>=") {
415 return functionDefinition.parameters().size() == 2
416 && matches(functionDefinition.type(), intrinsics.boolType());
417 }
418 if (functionDefinition.name() == "operator[]")
419 return checkGetter(CheckKind::Index);
420 if (functionDefinition.name() == "operator[]=")
421 return checkSetter(CheckKind::Index);
422 if (functionDefinition.name() == "operator&[]")
423 return checkAnder(CheckKind::Index);
424 if (functionDefinition.name().startsWith("operator.")) {
425 if (functionDefinition.name().endsWith("="))
426 return checkSetter(CheckKind::Dot);
427 return checkGetter(CheckKind::Dot);
428 }
429 if (functionDefinition.name().startsWith("operator&."))
430 return checkAnder(CheckKind::Dot);
431 return false;
432}
433
434class Checker : public Visitor {
435public:
436 Checker(const Intrinsics& intrinsics, Program& program)
437 : m_intrinsics(intrinsics)
438 , m_program(program)
439 {
440 }
441
442 ~Checker() = default;
443
444 void visit(Program&) override;
445
446 bool assignTypes();
447
448private:
449 bool checkShaderType(const AST::FunctionDefinition&);
450 void finishVisitingPropertyAccess(AST::PropertyAccessExpression&, AST::UnnamedType& wrappedBaseType, AST::UnnamedType* extraArgumentType = nullptr);
451 bool isBoolType(ResolvingType&);
452 struct RecurseInfo {
453 ResolvingType& resolvingType;
454 Optional<AST::AddressSpace>& addressSpace;
455 };
456 Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLValue = false);
457 Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLValue = false);
458 Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
459 bool recurseAndRequireBoolType(AST::Expression&);
460 void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, Optional<AST::AddressSpace> = WTF::nullopt);
461 void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, Optional<AST::AddressSpace> = WTF::nullopt);
462 void forwardType(AST::Expression&, ResolvingType&, Optional<AST::AddressSpace> = WTF::nullopt);
463
464 void visit(AST::FunctionDefinition&) override;
465 void visit(AST::EnumerationDefinition&) override;
466 void visit(AST::TypeReference&) override;
467 void visit(AST::VariableDeclaration&) override;
468 void visit(AST::AssignmentExpression&) override;
469 void visit(AST::ReadModifyWriteExpression&) override;
470 void visit(AST::DereferenceExpression&) override;
471 void visit(AST::MakePointerExpression&) override;
472 void visit(AST::MakeArrayReferenceExpression&) override;
473 void visit(AST::DotExpression&) override;
474 void visit(AST::IndexExpression&) override;
475 void visit(AST::VariableReference&) override;
476 void visit(AST::Return&) override;
477 void visit(AST::PointerType&) override;
478 void visit(AST::ArrayReferenceType&) override;
479 void visit(AST::IntegerLiteral&) override;
480 void visit(AST::UnsignedIntegerLiteral&) override;
481 void visit(AST::FloatLiteral&) override;
482 void visit(AST::NullLiteral&) override;
483 void visit(AST::BooleanLiteral&) override;
484 void visit(AST::EnumerationMemberLiteral&) override;
485 void visit(AST::LogicalNotExpression&) override;
486 void visit(AST::LogicalExpression&) override;
487 void visit(AST::IfStatement&) override;
488 void visit(AST::WhileLoop&) override;
489 void visit(AST::DoWhileLoop&) override;
490 void visit(AST::ForLoop&) override;
491 void visit(AST::SwitchStatement&) override;
492 void visit(AST::CommaExpression&) override;
493 void visit(AST::TernaryExpression&) override;
494 void visit(AST::CallExpression&) override;
495
496 HashMap<AST::Expression*, ResolvingType> m_typeMap;
497 HashMap<AST::Expression*, Optional<AST::AddressSpace>> m_addressSpaceMap;
498 HashSet<String> m_vertexEntryPoints;
499 HashSet<String> m_fragmentEntryPoints;
500 HashSet<String> m_computeEntryPoints;
501 const Intrinsics& m_intrinsics;
502 Program& m_program;
503};
504
505void Checker::visit(Program& program)
506{
507 // These visiting functions might add new global statements, so don't use foreach syntax.
508 for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
509 checkErrorAndVisit(program.typeDefinitions()[i]);
510 for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
511 checkErrorAndVisit(program.structureDefinitions()[i]);
512 for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
513 checkErrorAndVisit(program.enumerationDefinitions()[i]);
514 for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
515 checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
516
517 for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
518 checkErrorAndVisit(program.functionDefinitions()[i]);
519 for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
520 checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
521}
522
523bool Checker::assignTypes()
524{
525 for (auto& keyValuePair : m_typeMap) {
526 auto success = keyValuePair.value.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
527 keyValuePair.key->setType(unnamedType->clone());
528 return true;
529 }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
530 if (!resolvableTypeReference->resolvableType().resolvedType()) {
531 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
532 return false;
533 }
534 keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType()->clone());
535 return true;
536 }));
537 if (!success)
538 return false;
539 }
540
541 for (auto& keyValuePair : m_addressSpaceMap)
542 keyValuePair.key->setAddressSpace(keyValuePair.value);
543 return true;
544}
545
546bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
547{
548 switch (*functionDefinition.entryPointType()) {
549 case AST::EntryPointType::Vertex:
550 return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
551 case AST::EntryPointType::Fragment:
552 return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
553 case AST::EntryPointType::Compute:
554 return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
555 }
556}
557
558void Checker::visit(AST::FunctionDefinition& functionDefinition)
559{
560 if (functionDefinition.entryPointType()) {
561 if (!checkShaderType(functionDefinition)) {
562 setError();
563 return;
564 }
565 auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
566 if (!entryPointItems) {
567 setError();
568 return;
569 }
570 if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) {
571 setError();
572 return;
573 }
574 }
575 if (!checkOperatorOverload(functionDefinition, m_intrinsics, m_program.nameContext())) {
576 setError();
577 return;
578 }
579
580 Visitor::visit(functionDefinition);
581}
582
583static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right)
584{
585 return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
586 return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
587 if (matches(left, right))
588 return left->clone();
589 return WTF::nullopt;
590 }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
591 return matchAndCommit(left, right->resolvableType());
592 }));
593 }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
594 return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
595 return matchAndCommit(right, left->resolvableType());
596 }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
597 return matchAndCommit(left->resolvableType(), right->resolvableType());
598 }));
599 }));
600}
601
602static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
603{
604 return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
605 if (matches(unnamedType, resolvingType))
606 return unnamedType.clone();
607 return WTF::nullopt;
608 }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
609 return matchAndCommit(unnamedType, resolvingType->resolvableType());
610 }));
611}
612
613static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
614{
615 return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
616 if (matches(resolvingType, namedType))
617 return resolvingType->clone();
618 return WTF::nullopt;
619 }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
620 return matchAndCommit(namedType, resolvingType->resolvableType());
621 }));
622}
623
624void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
625{
626 auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
627 checkErrorAndVisit(enumerationDefinition.type());
628 auto& baseType = enumerationDefinition.type().unifyNode();
629 if (!is<AST::NamedType>(baseType))
630 return nullptr;
631 auto& namedType = downcast<AST::NamedType>(baseType);
632 if (!is<AST::NativeTypeDeclaration>(namedType))
633 return nullptr;
634 auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
635 if (!nativeTypeDeclaration.isInt())
636 return nullptr;
637 return &nativeTypeDeclaration;
638 })();
639 if (!baseType) {
640 setError();
641 return;
642 }
643
644 auto enumerationMembers = enumerationDefinition.enumerationMembers();
645
646 for (auto& member : enumerationMembers) {
647 if (!member.get().value())
648 continue;
649
650 bool success = false;
651 member.get().value()->visit(WTF::makeVisitor([&](AST::Expression& value) {
652 auto valueInfo = recurseAndGetInfo(value);
653 if (!valueInfo)
654 return;
655 success = static_cast<bool>(matchAndCommit(valueInfo->resolvingType, *baseType));
656 }));
657 if (!success) {
658 setError();
659 return;
660 }
661 }
662
663 int64_t nextValue = 0;
664 for (auto& member : enumerationMembers) {
665 if (member.get().value()) {
666 int64_t value;
667 member.get().value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) {
668 value = integerLiteral.valueForSelectedType();
669 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
670 value = unsignedIntegerLiteral.valueForSelectedType();
671 }, [&](auto&) {
672 ASSERT_NOT_REACHED();
673 }));
674 nextValue = baseType->successor()(value);
675 } else {
676 if (nextValue > std::numeric_limits<int>::max()) {
677 ASSERT(nextValue <= std::numeric_limits<unsigned>::max());
678 member.get().setValue(AST::ConstantExpression(AST::UnsignedIntegerLiteral(Lexer::Token(member.get().origin()), static_cast<unsigned>(nextValue))));
679 }
680 ASSERT(nextValue >= std::numeric_limits<int>::min());
681 member.get().setValue(AST::ConstantExpression(AST::IntegerLiteral(Lexer::Token(member.get().origin()), static_cast<int>(nextValue))));
682 nextValue = baseType->successor()(nextValue);
683 }
684 }
685
686 auto getValue = [&](AST::EnumerationMember& member) -> int64_t {
687 int64_t value;
688 ASSERT(member.value());
689 member.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) {
690 value = integerLiteral.value();
691 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
692 value = unsignedIntegerLiteral.value();
693 }, [&](auto&) {
694 ASSERT_NOT_REACHED();
695 }));
696 return value;
697 };
698
699 for (size_t i = 0; i < enumerationMembers.size(); ++i) {
700 auto value = getValue(enumerationMembers[i].get());
701 for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
702 auto otherValue = getValue(enumerationMembers[j].get());
703 if (value == otherValue) {
704 setError();
705 return;
706 }
707 }
708 }
709
710 bool foundZero = false;
711 for (auto& member : enumerationMembers) {
712 if (!getValue(member.get())) {
713 foundZero = true;
714 break;
715 }
716 }
717 if (!foundZero) {
718 setError();
719 return;
720 }
721}
722
723void Checker::visit(AST::TypeReference& typeReference)
724{
725 ASSERT(typeReference.resolvedType());
726
727 for (auto& typeArgument : typeReference.typeArguments())
728 checkErrorAndVisit(typeArgument);
729}
730
731auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLValue) -> Optional<RecurseInfo>
732{
733 Visitor::visit(expression);
734 if (error())
735 return WTF::nullopt;
736 return getInfo(expression, requiresLValue);
737}
738
739auto Checker::getInfo(AST::Expression& expression, bool requiresLValue) -> Optional<RecurseInfo>
740{
741 auto typeIterator = m_typeMap.find(&expression);
742 ASSERT(typeIterator != m_typeMap.end());
743
744 auto addressSpaceIterator = m_addressSpaceMap.find(&expression);
745 ASSERT(addressSpaceIterator != m_addressSpaceMap.end());
746 if (requiresLValue && !addressSpaceIterator->value) {
747 setError();
748 return WTF::nullopt;
749 }
750 return {{ typeIterator->value, addressSpaceIterator->value }};
751}
752
753void Checker::visit(AST::VariableDeclaration& variableDeclaration)
754{
755 // ReadModifyWriteExpressions are the only place where anonymous variables exist,
756 // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type.
757 checkErrorAndVisit(*variableDeclaration.type());
758 if (variableDeclaration.initializer()) {
759 auto& lhsType = *variableDeclaration.type();
760 auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
761 if (!initializerInfo)
762 return;
763 if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) {
764 setError();
765 return;
766 }
767 }
768}
769
770void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType>&& unnamedType, Optional<AST::AddressSpace> addressSpace)
771{
772 auto addResult = m_typeMap.add(&expression, WTFMove(unnamedType));
773 ASSERT_UNUSED(addResult, addResult.isNewEntry);
774 auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace);
775 ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry);
776}
777
778void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, Optional<AST::AddressSpace> addressSpace)
779{
780 auto addResult = m_typeMap.add(&expression, WTFMove(resolvableTypeReference));
781 ASSERT_UNUSED(addResult, addResult.isNewEntry);
782 auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace);
783 ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry);
784}
785
786void Checker::visit(AST::AssignmentExpression& assignmentExpression)
787{
788 auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true);
789 if (!leftInfo)
790 return;
791
792 auto rightInfo = recurseAndGetInfo(assignmentExpression.right());
793 if (!rightInfo)
794 return;
795
796 auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType);
797 if (!resultType) {
798 setError();
799 return;
800 }
801
802 assignType(assignmentExpression, WTFMove(*resultType));
803}
804
805void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, Optional<AST::AddressSpace> addressSpace)
806{
807 resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
808 auto addResult = m_typeMap.add(&expression, result->clone());
809 ASSERT_UNUSED(addResult, addResult.isNewEntry);
810 }, [&](RefPtr<ResolvableTypeReference>& result) {
811 auto addResult = m_typeMap.add(&expression, result.copyRef());
812 ASSERT_UNUSED(addResult, addResult.isNewEntry);
813 }));
814 auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace);
815 ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry);
816}
817
818void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
819{
820 auto lValueInfo = recurseAndGetInfo(readModifyWriteExpression.lValue(), true);
821 if (!lValueInfo)
822 return;
823
824 // FIXME: Figure out what to do with the ReadModifyWriteExpression's AnonymousVariables.
825
826 auto newValueInfo = recurseAndGetInfo(*readModifyWriteExpression.newValueExpression());
827 if (!newValueInfo)
828 return;
829
830 if (!matchAndCommit(lValueInfo->resolvingType, newValueInfo->resolvingType)) {
831 setError();
832 return;
833 }
834
835 auto resultInfo = recurseAndGetInfo(*readModifyWriteExpression.resultExpression());
836 if (!resultInfo)
837 return;
838
839 forwardType(readModifyWriteExpression, resultInfo->resolvingType);
840}
841
842static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
843{
844 return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
845 return &type;
846 }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
847 // FIXME: If the type isn't committed, should we just commit() it now?
848 return type->resolvableType().resolvedType();
849 }));
850}
851
852void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
853{
854 auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer());
855 if (!pointerInfo)
856 return;
857
858 auto* unnamedType = getUnnamedType(pointerInfo->resolvingType);
859
860 auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* {
861 if (!unnamedType)
862 return nullptr;
863 auto& unifyNode = unnamedType->unifyNode();
864 if (!is<AST::UnnamedType>(unifyNode))
865 return nullptr;
866 auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
867 if (!is<AST::PointerType>(unnamedUnifyType))
868 return nullptr;
869 return &downcast<AST::PointerType>(unnamedUnifyType);
870 })(unnamedType);
871 if (!pointerType) {
872 setError();
873 return;
874 }
875
876 assignType(dereferenceExpression, pointerType->clone(), pointerType->addressSpace());
877}
878
879void Checker::visit(AST::MakePointerExpression& makePointerExpression)
880{
881 auto lValueInfo = recurseAndGetInfo(makePointerExpression.lValue(), true);
882 if (!lValueInfo)
883 return;
884
885 auto* lValueType = getUnnamedType(lValueInfo->resolvingType);
886 if (!lValueType) {
887 setError();
888 return;
889 }
890
891 assignType(makePointerExpression, makeUniqueRef<AST::PointerType>(Lexer::Token(makePointerExpression.origin()), *lValueInfo->addressSpace, lValueType->clone()));
892}
893
894void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
895{
896 auto lValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.lValue());
897 if (!lValueInfo)
898 return;
899
900 auto* lValueType = getUnnamedType(lValueInfo->resolvingType);
901 if (!lValueType) {
902 setError();
903 return;
904 }
905
906 auto& unifyNode = lValueType->unifyNode();
907 if (is<AST::UnnamedType>(unifyNode)) {
908 auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
909 if (is<AST::PointerType>(unnamedType)) {
910 auto& pointerType = downcast<AST::PointerType>(unnamedType);
911 // FIXME: Save the fact that we're not targetting the item; we're targetting the item's inner element.
912 assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), pointerType.addressSpace(), pointerType.elementType().clone()));
913 return;
914 }
915
916 if (!lValueInfo->addressSpace) {
917 setError();
918 return;
919 }
920
921 if (is<AST::ArrayType>(unnamedType)) {
922 auto& arrayType = downcast<AST::ArrayType>(unnamedType);
923 // FIXME: Save the number of elements.
924 assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *lValueInfo->addressSpace, arrayType.type().clone()));
925 return;
926 }
927 }
928
929 if (!lValueInfo->addressSpace) {
930 setError();
931 return;
932 }
933
934 assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *lValueInfo->addressSpace, lValueType->clone()));
935}
936
// Completes type-checking of a property access (a dot or index expression) once its base has
// been visited. It resolves up to three accessor overloads from the expression's candidate lists:
//  - a "get" overload taking (base[, index]);
//  - an "and" overload taking a reference to the base [, index] and returning a pointer;
//  - a "set" overload taking (base[, index], value), which must return the base type.
// At least one of get/and must resolve, and if both do, their value types must match. The access
// is given that value type, plus an address space taken from the base.
//
// wrappedBaseType: the base's type as produced by recurseAndWrapBaseType.
// extraArgumentType: the index type for index expressions; nullptr for dot expressions.
void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& propertyAccessExpression, AST::UnnamedType& wrappedBaseType, AST::UnnamedType* extraArgumentType)
{
    using OverloadResolution = std::tuple<AST::FunctionDeclaration*, AST::UnnamedType*>;

    // Getter: called with the base (and index, if any); its return type is the accessed value's type.
    AST::FunctionDeclaration* getFunction;
    AST::UnnamedType* getReturnType;
    std::tie(getFunction, getReturnType) = ([&]() -> OverloadResolution {
        ResolvingType getArgumentType1(wrappedBaseType.clone());
        Optional<ResolvingType> getArgumentType2;
        if (extraArgumentType)
            getArgumentType2 = ResolvingType(extraArgumentType->clone());

        Vector<std::reference_wrapper<ResolvingType>> getArgumentTypes;
        getArgumentTypes.append(getArgumentType1);
        if (getArgumentType2)
            getArgumentTypes.append(*getArgumentType2);

        auto* getFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleGetOverloads(), getArgumentTypes, nullptr);
        if (!getFunction)
            return std::make_pair(nullptr, nullptr);
        return std::make_pair(getFunction, &getFunction->type());
    })();

    // "And" accessor: called with a reference to the base. Its return type must be a pointer, and
    // the pointee is the accessed value's type.
    AST::FunctionDeclaration* andFunction;
    AST::UnnamedType* andReturnType;
    std::tie(andFunction, andReturnType) = ([&]() -> OverloadResolution {
        // Array references are passed as-is. Arrays become thread-space array references to their
        // element type. Pointer bases have no "and" argument (no overload is tried). Anything else
        // becomes a thread-space pointer to the base.
        auto computeAndArgumentType = [&](AST::UnnamedType& unnamedType) -> Optional<ResolvingType> {
            if (is<AST::ArrayReferenceType>(unnamedType))
                return { unnamedType.clone() };
            if (is<AST::ArrayType>(unnamedType))
                return { ResolvingType(makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::ArrayType>(unnamedType).type().clone())) };
            if (is<AST::PointerType>(unnamedType))
                return WTF::nullopt;
            return { ResolvingType(makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::TypeReference>(unnamedType).clone())) };
        };
        // A non-pointer return type yields null, i.e. the "and" overload is treated as unusable.
        auto computeAndReturnType = [&](AST::UnnamedType& unnamedType) -> AST::UnnamedType* {
            if (is<AST::PointerType>(unnamedType))
                return &downcast<AST::PointerType>(unnamedType).elementType();
            return nullptr;
        };

        auto andArgumentType1 = computeAndArgumentType(wrappedBaseType);
        if (!andArgumentType1)
            return std::make_pair(nullptr, nullptr);
        Optional<ResolvingType> andArgumentType2;
        if (extraArgumentType)
            andArgumentType2 = ResolvingType(extraArgumentType->clone());

        Vector<std::reference_wrapper<ResolvingType>> andArgumentTypes;
        andArgumentTypes.append(*andArgumentType1);
        if (andArgumentType2)
            andArgumentTypes.append(*andArgumentType2);

        auto* andFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleAndOverloads(), andArgumentTypes, nullptr);
        if (!andFunction)
            return std::make_pair(nullptr, nullptr);
        return std::make_pair(andFunction, computeAndReturnType(andFunction->type()));
    })();

    // The property must be readable one way or another.
    if (!getReturnType && !andReturnType) {
        setError();
        return;
    }

    // When both accessors exist, they must agree on the value type.
    if (getReturnType && andReturnType && !matches(*getReturnType, *andReturnType)) {
        setError();
        return;
    }

    // Setter (optional): called with the base, the index if any, and the new value. It must return
    // the updated base.
    AST::FunctionDeclaration* setFunction;
    AST::UnnamedType* setReturnType;
    std::tie(setFunction, setReturnType) = ([&]() -> OverloadResolution {
        ResolvingType setArgument1Type(wrappedBaseType.clone());
        Optional<ResolvingType> setArgumentType2;
        if (extraArgumentType)
            setArgumentType2 = ResolvingType(extraArgumentType->clone());
        ResolvingType setArgument3Type(getReturnType ? getReturnType->clone() : andReturnType->clone());

        Vector<std::reference_wrapper<ResolvingType>> setArgumentTypes;
        setArgumentTypes.append(setArgument1Type);
        if (setArgumentType2)
            setArgumentTypes.append(*setArgumentType2);
        setArgumentTypes.append(setArgument3Type);

        auto* setFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleSetOverloads(), setArgumentTypes, nullptr);
        if (!setFunction)
            return std::make_pair(nullptr, nullptr);
        return std::make_pair(setFunction, &setFunction->type());
    })();

    if (setFunction) {
        if (!matches(setFunction->type(), wrappedBaseType)) {
            setError();
            return;
        }
    }

    // NOTE(review): the early return above guarantees at least one of getReturnType/andReturnType
    // is non-null, so this condition is always true here.
    Optional<AST::AddressSpace> addressSpace;
    if (getReturnType || andReturnType) {
        // FIXME: The reference compiler has "else if (!node.base.isLValue && !baseType.isArrayRef)",
        // but I don't understand why it exists. I haven't written it here, and I'll investigate
        // if we can remove it from the reference compiler.
        if (is<AST::ReferenceType>(wrappedBaseType))
            addressSpace = downcast<AST::ReferenceType>(wrappedBaseType).addressSpace();
        else {
            // A non-reference base must itself have an address space; otherwise the access is rejected.
            auto addressSpaceIterator = m_addressSpaceMap.find(&propertyAccessExpression.base());
            ASSERT(addressSpaceIterator != m_addressSpaceMap.end());
            if (addressSpaceIterator->value)
                addressSpace = *addressSpaceIterator->value;
            else {
                setError();
                return;
            }
        }
    }

    // FIXME: Generate the call expressions

    assignType(propertyAccessExpression, getReturnType ? getReturnType->clone() : andReturnType->clone(), addressSpace);
}
1057
1058Optional<UniqueRef<AST::UnnamedType>> Checker::recurseAndWrapBaseType(AST::PropertyAccessExpression& propertyAccessExpression)
1059{
1060 auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base());
1061 if (!baseInfo)
1062 return WTF::nullopt;
1063
1064 auto* baseType = getUnnamedType(baseInfo->resolvingType);
1065 if (!baseType) {
1066 setError();
1067 return WTF::nullopt;
1068 }
1069 auto& baseUnifyNode = baseType->unifyNode();
1070 if (is<AST::UnnamedType>(baseUnifyNode))
1071 return downcast<AST::UnnamedType>(baseUnifyNode).clone();
1072 ASSERT(is<AST::NamedType>(baseUnifyNode));
1073 return { AST::TypeReference::wrap(Lexer::Token(propertyAccessExpression.origin()), downcast<AST::NamedType>(baseUnifyNode)) };
1074}
1075
1076void Checker::visit(AST::DotExpression& dotExpression)
1077{
1078 auto baseType = recurseAndWrapBaseType(dotExpression);
1079 if (!baseType)
1080 return;
1081
1082 finishVisitingPropertyAccess(dotExpression, *baseType);
1083}
1084
1085void Checker::visit(AST::IndexExpression& indexExpression)
1086{
1087 auto baseType = recurseAndWrapBaseType(indexExpression);
1088 if (!baseType)
1089 return;
1090
1091 auto indexInfo = recurseAndGetInfo(indexExpression.indexExpression());
1092 if (!indexInfo)
1093 return;
1094 auto indexExpressionType = getUnnamedType(indexInfo->resolvingType);
1095 if (!indexExpressionType) {
1096 setError();
1097 return;
1098 }
1099
1100 finishVisitingPropertyAccess(indexExpression, WTFMove(*baseType), indexExpressionType);
1101}
1102
1103void Checker::visit(AST::VariableReference& variableReference)
1104{
1105 ASSERT(variableReference.variable());
1106 ASSERT(variableReference.variable()->type());
1107
1108 Optional<AST::AddressSpace> addressSpace;
1109 if (!variableReference.variable()->isAnonymous())
1110 addressSpace = AST::AddressSpace::Thread;
1111 assignType(variableReference, variableReference.variable()->type()->clone(), addressSpace);
1112}
1113
1114void Checker::visit(AST::Return& returnStatement)
1115{
1116 ASSERT(returnStatement.function());
1117 if (returnStatement.value()) {
1118 auto valueInfo = recurseAndGetInfo(*returnStatement.value());
1119 if (!valueInfo)
1120 return;
1121 if (!matchAndCommit(valueInfo->resolvingType, returnStatement.function()->type()))
1122 setError();
1123 return;
1124 }
1125
1126 if (!matches(returnStatement.function()->type(), m_intrinsics.voidType()))
1127 setError();
1128}
1129
1130void Checker::visit(AST::PointerType&)
1131{
1132 // Following pointer types can cause infinite loops because of data structures
1133 // like linked lists.
1134 // FIXME: Make sure this function should be empty
1135}
1136
1137void Checker::visit(AST::ArrayReferenceType&)
1138{
1139 // Following array reference types can cause infinite loops because of data
1140 // structures like linked lists.
1141 // FIXME: Make sure this function should be empty
1142}
1143
1144void Checker::visit(AST::IntegerLiteral& integerLiteral)
1145{
1146 assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type())));
1147}
1148
1149void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
1150{
1151 assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type())));
1152}
1153
1154void Checker::visit(AST::FloatLiteral& floatLiteral)
1155{
1156 assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type())));
1157}
1158
1159void Checker::visit(AST::NullLiteral& nullLiteral)
1160{
1161 assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type())));
1162}
1163
1164void Checker::visit(AST::BooleanLiteral& booleanLiteral)
1165{
1166 assignType(booleanLiteral, AST::TypeReference::wrap(Lexer::Token(booleanLiteral.origin()), m_intrinsics.boolType()));
1167}
1168
1169void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
1170{
1171 ASSERT(enumerationMemberLiteral.enumerationDefinition());
1172 auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition();
1173 assignType(enumerationMemberLiteral, AST::TypeReference::wrap(Lexer::Token(enumerationMemberLiteral.origin()), enumerationDefinition));
1174}
1175
1176bool Checker::isBoolType(ResolvingType& resolvingType)
1177{
1178 return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
1179 return matches(left, m_intrinsics.boolType());
1180 }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
1181 return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
1182 }));
1183}
1184
1185bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
1186{
1187 auto expressionInfo = recurseAndGetInfo(expression);
1188 if (!expressionInfo)
1189 return false;
1190 if (!isBoolType(expressionInfo->resolvingType)) {
1191 setError();
1192 return false;
1193 }
1194 return true;
1195}
1196
1197void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
1198{
1199 if (!recurseAndRequireBoolType(logicalNotExpression.operand()))
1200 return;
1201 assignType(logicalNotExpression, AST::TypeReference::wrap(Lexer::Token(logicalNotExpression.origin()), m_intrinsics.boolType()));
1202}
1203
1204void Checker::visit(AST::LogicalExpression& logicalExpression)
1205{
1206 if (!recurseAndRequireBoolType(logicalExpression.left()))
1207 return;
1208 if (!recurseAndRequireBoolType(logicalExpression.right()))
1209 return;
1210 assignType(logicalExpression, AST::TypeReference::wrap(Lexer::Token(logicalExpression.origin()), m_intrinsics.boolType()));
1211}
1212
1213void Checker::visit(AST::IfStatement& ifStatement)
1214{
1215 if (!recurseAndRequireBoolType(ifStatement.conditional()))
1216 return;
1217 checkErrorAndVisit(ifStatement.body());
1218 if (ifStatement.elseBody())
1219 checkErrorAndVisit(*ifStatement.elseBody());
1220}
1221
1222void Checker::visit(AST::WhileLoop& whileLoop)
1223{
1224 if (!recurseAndRequireBoolType(whileLoop.conditional()))
1225 return;
1226 checkErrorAndVisit(whileLoop.body());
1227}
1228
1229void Checker::visit(AST::DoWhileLoop& doWhileLoop)
1230{
1231 checkErrorAndVisit(doWhileLoop.body());
1232 recurseAndRequireBoolType(doWhileLoop.conditional());
1233}
1234
1235void Checker::visit(AST::ForLoop& forLoop)
1236{
1237 WTF::visit(WTF::makeVisitor([&](AST::VariableDeclarationsStatement& variableDeclarationsStatement) {
1238 checkErrorAndVisit(variableDeclarationsStatement);
1239 }, [&](UniqueRef<AST::Expression>& expression) {
1240 checkErrorAndVisit(expression);
1241 }), forLoop.initialization());
1242 if (error())
1243 return;
1244 if (forLoop.condition()) {
1245 if (!recurseAndRequireBoolType(*forLoop.condition()))
1246 return;
1247 }
1248 if (forLoop.increment())
1249 checkErrorAndVisit(*forLoop.increment());
1250 checkErrorAndVisit(forLoop.body());
1251}
1252
1253void Checker::visit(AST::SwitchStatement& switchStatement)
1254{
1255 auto* valueType = ([&]() -> AST::NamedType* {
1256 auto valueInfo = recurseAndGetInfo(switchStatement.value());
1257 if (!valueInfo)
1258 return nullptr;
1259 auto* valueType = getUnnamedType(valueInfo->resolvingType);
1260 if (!valueType)
1261 return nullptr;
1262 auto& valueUnifyNode = valueType->unifyNode();
1263 if (!is<AST::NamedType>(valueUnifyNode))
1264 return nullptr;
1265 auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
1266 if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
1267 && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
1268 return nullptr;
1269 return &valueNamedUnifyNode;
1270 })();
1271 if (!valueType) {
1272 setError();
1273 return;
1274 }
1275
1276 bool hasDefault = false;
1277 for (auto& switchCase : switchStatement.switchCases()) {
1278 checkErrorAndVisit(switchCase.block());
1279 if (!switchCase.value()) {
1280 hasDefault = true;
1281 continue;
1282 }
1283 bool success;
1284 switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) {
1285 success = static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
1286 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
1287 success = static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
1288 }, [&](AST::FloatLiteral& floatLiteral) {
1289 success = static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
1290 }, [&](AST::NullLiteral& nullLiteral) {
1291 success = static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type()));
1292 }, [&](AST::BooleanLiteral&) {
1293 success = matches(*valueType, m_intrinsics.boolType());
1294 }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1295 ASSERT(enumerationMemberLiteral.enumerationDefinition());
1296 success = matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
1297 }));
1298 if (!success) {
1299 setError();
1300 return;
1301 }
1302 }
1303
1304 for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
1305 auto& firstCase = switchStatement.switchCases()[i];
1306 for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
1307 auto& secondCase = switchStatement.switchCases()[j];
1308
1309 if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
1310 continue;
1311
1312 if (!static_cast<bool>(firstCase.value())) {
1313 setError();
1314 return;
1315 }
1316
1317 bool success = true;
1318 firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) {
1319 secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) {
1320 success = firstIntegerLiteral.value() != secondIntegerLiteral.value();
1321 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) {
1322 success = static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
1323 }, [](auto&) {
1324 }));
1325 }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) {
1326 secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) {
1327 success = static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
1328 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) {
1329 success = firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
1330 }, [](auto&) {
1331 }));
1332 }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) {
1333 secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral&) {
1334 }, [&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) {
1335 ASSERT(firstEnumerationMemberLiteral.enumerationMember());
1336 ASSERT(secondEnumerationMemberLiteral.enumerationMember());
1337 success = firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
1338 }, [](auto&) {
1339 }));
1340 }, [](auto&) {
1341 }));
1342 }
1343 }
1344
1345 if (!hasDefault) {
1346 if (is<AST::NativeTypeDeclaration>(*valueType)) {
1347 HashSet<int64_t> values;
1348 bool zeroValueExists;
1349 for (auto& switchCase : switchStatement.switchCases()) {
1350 int64_t value;
1351 switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) {
1352 value = integerLiteral.valueForSelectedType();
1353 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
1354 value = unsignedIntegerLiteral.valueForSelectedType();
1355 }, [](auto&) {
1356 ASSERT_NOT_REACHED();
1357 }));
1358 if (!value)
1359 zeroValueExists = true;
1360 else
1361 values.add(value);
1362 }
1363 bool success = true;
1364 downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
1365 if (!value) {
1366 if (!zeroValueExists) {
1367 success = false;
1368 return true;
1369 }
1370 return false;
1371 }
1372 if (!values.contains(value)) {
1373 success = false;
1374 return true;
1375 }
1376 return false;
1377 });
1378 if (!success) {
1379 setError();
1380 return;
1381 }
1382 } else {
1383 ASSERT(is<AST::EnumerationDefinition>(*valueType));
1384 HashSet<AST::EnumerationMember*> values;
1385 for (auto& switchCase : switchStatement.switchCases()) {
1386 switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1387 ASSERT(enumerationMemberLiteral.enumerationMember());
1388 values.add(enumerationMemberLiteral.enumerationMember());
1389 }, [](auto&) {
1390 ASSERT_NOT_REACHED();
1391 }));
1392 }
1393 for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
1394 if (!values.contains(&enumerationMember.get())) {
1395 setError();
1396 return;
1397 }
1398 }
1399 }
1400 }
1401}
1402
1403void Checker::visit(AST::CommaExpression& commaExpression)
1404{
1405 ASSERT(commaExpression.list().size() > 0);
1406 Visitor::visit(commaExpression);
1407 if (error())
1408 return;
1409 auto lastInfo = getInfo(commaExpression.list().last());
1410 forwardType(commaExpression, lastInfo->resolvingType);
1411}
1412
1413void Checker::visit(AST::TernaryExpression& ternaryExpression)
1414{
1415 auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate());
1416 if (!predicateInfo)
1417 return;
1418
1419 auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
1420 auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
1421
1422 auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
1423 if (!resultType) {
1424 setError();
1425 return;
1426 }
1427
1428 assignType(ternaryExpression, WTFMove(*resultType));
1429}
1430
// Type-checks a call expression:
//  1. visits every argument;
//  2. resolves the callee among the call's overloads, falling back to resolveByInstantiation,
//     whose result is appended to the program as a native function declaration;
//  3. commits each argument type to the corresponding parameter type;
//  4. records the callee on the call and gives the call the callee's return type.
void Checker::visit(AST::CallExpression& callExpression)
{
    // NOTE(review): `types` keeps references to each argument's ResolvingType as returned through
    // RecurseInfo. If RecurseInfo refers into m_typeMap's storage, type-checking later arguments
    // (which adds map entries) could rehash the map and leave earlier references dangling.
    // Confirm that m_typeMap values have stable addresses.
    Vector<std::reference_wrapper<ResolvingType>> types;
    types.reserveInitialCapacity(callExpression.arguments().size());
    for (auto& argument : callExpression.arguments()) {
        auto argumentInfo = recurseAndGetInfo(argument);
        if (!argumentInfo)
            return;
        types.uncheckedAppend(argumentInfo->resolvingType);
    }
    // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
    // We don't want to recurse to the same node twice.

    ASSERT(callExpression.hasOverloads());
    auto* function = resolveFunctionOverloadImpl(*callExpression.overloads(), types, callExpression.castReturnType());
    if (!function) {
        // No existing overload matched: try to instantiate one for these argument types and add
        // it to the program's native function declarations.
        if (auto newFunction = resolveByInstantiation(callExpression, types, m_intrinsics)) {
            m_program.append(WTFMove(*newFunction));
            function = &m_program.nativeFunctionDeclarations().last();
        }
    }

    if (!function) {
        setError();
        return;
    }

    // Presumably overload resolution guarantees one parameter per argument (types.size()),
    // since types[i] is indexed by parameter position. TODO confirm.
    for (size_t i = 0; i < function->parameters().size(); ++i) {
        if (!matchAndCommit(types[i].get(), *function->parameters()[i].type())) {
            setError();
            return;
        }
    }

    callExpression.setFunction(*function);

    assignType(callExpression, function->type().clone());
}
1469
1470bool check(Program& program)
1471{
1472 Checker checker(program.intrinsics(), program);
1473 checker.checkErrorAndVisit(program);
1474 if (checker.error())
1475 return false;
1476 return checker.assignTypes();
1477}
1478
1479} // namespace WHLSL
1480
1481} // namespace WebCore
1482
1483#endif // ENABLE(WEBGPU)
1484