1 | /* |
2 | * Copyright (C) 2019 Apple Inc. All rights reserved. |
3 | * |
4 | * Redistribution and use in source and binary forms, with or without |
5 | * modification, are permitted provided that the following conditions |
6 | * are met: |
7 | * 1. Redistributions of source code must retain the above copyright |
8 | * notice, this list of conditions and the following disclaimer. |
9 | * 2. Redistributions in binary form must reproduce the above copyright |
10 | * notice, this list of conditions and the following disclaimer in the |
11 | * documentation and/or other materials provided with the distribution. |
12 | * |
13 | * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS'' |
14 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, |
15 | * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR |
16 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS |
17 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
18 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
19 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
20 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
21 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
22 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF |
23 | * THE POSSIBILITY OF SUCH DAMAGE. |
24 | */ |
25 | |
26 | #include "config.h" |
27 | #include "WHLSLChecker.h" |
28 | |
29 | #if ENABLE(WEBGPU) |
30 | |
31 | #include "WHLSLArrayReferenceType.h" |
32 | #include "WHLSLArrayType.h" |
33 | #include "WHLSLAssignmentExpression.h" |
34 | #include "WHLSLCallExpression.h" |
35 | #include "WHLSLCommaExpression.h" |
36 | #include "WHLSLDereferenceExpression.h" |
37 | #include "WHLSLDoWhileLoop.h" |
38 | #include "WHLSLDotExpression.h" |
39 | #include "WHLSLForLoop.h" |
40 | #include "WHLSLGatherEntryPointItems.h" |
41 | #include "WHLSLIfStatement.h" |
42 | #include "WHLSLIndexExpression.h" |
43 | #include "WHLSLInferTypes.h" |
44 | #include "WHLSLLogicalExpression.h" |
45 | #include "WHLSLLogicalNotExpression.h" |
46 | #include "WHLSLMakeArrayReferenceExpression.h" |
47 | #include "WHLSLMakePointerExpression.h" |
48 | #include "WHLSLPointerType.h" |
49 | #include "WHLSLProgram.h" |
50 | #include "WHLSLReadModifyWriteExpression.h" |
51 | #include "WHLSLResolvableType.h" |
52 | #include "WHLSLResolveOverloadImpl.h" |
53 | #include "WHLSLResolvingType.h" |
54 | #include "WHLSLReturn.h" |
55 | #include "WHLSLSwitchStatement.h" |
56 | #include "WHLSLTernaryExpression.h" |
57 | #include "WHLSLVisitor.h" |
58 | #include "WHLSLWhileLoop.h" |
59 | #include <wtf/HashMap.h> |
60 | #include <wtf/HashSet.h> |
61 | #include <wtf/Ref.h> |
62 | #include <wtf/Vector.h> |
63 | #include <wtf/text/WTFString.h> |
64 | |
65 | namespace WebCore { |
66 | |
67 | namespace WHLSL { |
68 | |
69 | class PODChecker : public Visitor { |
70 | public: |
71 | PODChecker() = default; |
72 | |
73 | virtual ~PODChecker() = default; |
74 | |
75 | void visit(AST::EnumerationDefinition& enumerationDefinition) override |
76 | { |
77 | Visitor::visit(enumerationDefinition); |
78 | } |
79 | |
80 | void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override |
81 | { |
82 | if (!nativeTypeDeclaration.isNumber() |
83 | && !nativeTypeDeclaration.isVector() |
84 | && !nativeTypeDeclaration.isMatrix()) |
85 | setError(); |
86 | } |
87 | |
88 | void visit(AST::StructureDefinition& structureDefinition) override |
89 | { |
90 | Visitor::visit(structureDefinition); |
91 | } |
92 | |
93 | void visit(AST::TypeDefinition& typeDefinition) override |
94 | { |
95 | Visitor::visit(typeDefinition); |
96 | } |
97 | |
98 | void visit(AST::ArrayType& arrayType) override |
99 | { |
100 | Visitor::visit(arrayType); |
101 | } |
102 | |
103 | void visit(AST::PointerType&) override |
104 | { |
105 | setError(); |
106 | } |
107 | |
108 | void visit(AST::ArrayReferenceType&) override |
109 | { |
110 | setError(); |
111 | } |
112 | |
113 | void visit(AST::TypeReference& typeReference) override |
114 | { |
115 | ASSERT(typeReference.resolvedType()); |
116 | checkErrorAndVisit(*typeReference.resolvedType()); |
117 | } |
118 | }; |
119 | |
120 | static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(AST::CallExpression& callExpression, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics) |
121 | { |
122 | const bool isOperator = true; |
123 | auto returnType = makeUniqueRef<AST::PointerType>(Lexer::Token(callExpression.origin()), firstArgument.addressSpace(), firstArgument.elementType().clone()); |
124 | AST::VariableDeclarations parameters; |
125 | parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { firstArgument.clone() }, String(), WTF::nullopt, WTF::nullopt)); |
126 | parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.uintType()) }, String(), WTF::nullopt, WTF::nullopt)); |
127 | return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]" , String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator)); |
128 | } |
129 | |
130 | static AST::NativeFunctionDeclaration resolveWithOperatorLength(AST::CallExpression& callExpression, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics) |
131 | { |
132 | const bool isOperator = true; |
133 | auto returnType = AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.uintType()); |
134 | AST::VariableDeclarations parameters; |
135 | parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { firstArgument.clone() }, String(), WTF::nullopt, WTF::nullopt)); |
136 | return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length" , String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator)); |
137 | } |
138 | |
139 | static AST::NativeFunctionDeclaration resolveWithReferenceComparator(AST::CallExpression& callExpression, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics) |
140 | { |
141 | const bool isOperator = true; |
142 | auto returnType = AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.boolType()); |
143 | auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> { |
144 | return unnamedType->clone(); |
145 | }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> { |
146 | return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> { |
147 | return unnamedType->clone(); |
148 | }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> { |
149 | // We encountered "null == null". |
150 | // FIXME: This can probably be generalized, using the "preferred type" infrastructure used by generic literals |
151 | ASSERT_NOT_REACHED(); |
152 | return AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.intType()); |
153 | })); |
154 | })); |
155 | AST::VariableDeclarations parameters; |
156 | parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { argumentType->clone() }, String(), WTF::nullopt, WTF::nullopt)); |
157 | parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { WTFMove(argumentType) }, String(), WTF::nullopt, WTF::nullopt)); |
158 | return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==" , String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator)); |
159 | } |
160 | |
// How well an operand fits a synthesized reference "operator==":
// Yes for a concrete reference type, Maybe for an unresolved null literal, No otherwise.
enum class Acceptability { Yes, Maybe, No };
166 | |
167 | static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(AST::CallExpression& callExpression, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics) |
168 | { |
169 | if (callExpression.name() == "operator&[]" && types.size() == 2) { |
170 | auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* { |
171 | if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType))) |
172 | return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)); |
173 | return nullptr; |
174 | }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* { |
175 | return nullptr; |
176 | })); |
177 | bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool { |
178 | return matches(unnamedType, intrinsics.uintType()); |
179 | }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool { |
180 | return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType()); |
181 | })); |
182 | if (firstArgumentArrayRef && secondArgumentIsUint) |
183 | return resolveWithOperatorAnderIndexer(callExpression, *firstArgumentArrayRef, intrinsics); |
184 | } else if (callExpression.name() == "operator.length" && types.size() == 1) { |
185 | auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* { |
186 | if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType))) |
187 | return &unnamedType; |
188 | return nullptr; |
189 | }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* { |
190 | return nullptr; |
191 | })); |
192 | if (firstArgumentReference) |
193 | return resolveWithOperatorLength(callExpression, *firstArgumentReference, intrinsics); |
194 | } else if (callExpression.name() == "operator==" && types.size() == 2) { |
195 | auto acceptability = [](ResolvingType& resolvingType) -> Acceptability { |
196 | return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability { |
197 | return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) ? Acceptability::Yes : Acceptability::No; |
198 | }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability { |
199 | return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No; |
200 | })); |
201 | }; |
202 | auto leftAcceptability = acceptability(types[0].get()); |
203 | auto rightAcceptability = acceptability(types[1].get()); |
204 | bool success = false; |
205 | if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) { |
206 | auto& unnamedType1 = types[0].get().getUnnamedType(); |
207 | auto& unnamedType2 = types[1].get().getUnnamedType(); |
208 | success = matches(unnamedType1, unnamedType2); |
209 | } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes) |
210 | || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe)) |
211 | success = true; |
212 | if (success) |
213 | return resolveWithReferenceComparator(callExpression, types[0].get(), types[1].get(), intrinsics); |
214 | } |
215 | return WTF::nullopt; |
216 | } |
217 | |
218 | static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics) |
219 | { |
220 | { |
221 | auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool { |
222 | for (size_t i = 0; i < items.size(); ++i) { |
223 | for (size_t j = i + 1; j < items.size(); ++j) { |
224 | if (items[i].semantic == items[j].semantic) |
225 | return false; |
226 | } |
227 | } |
228 | return true; |
229 | }; |
230 | if (!checkDuplicateSemantics(inputItems)) |
231 | return false; |
232 | if (!checkDuplicateSemantics(outputItems)) |
233 | return false; |
234 | } |
235 | |
236 | { |
237 | auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool { |
238 | for (auto& item : items) { |
239 | auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool { |
240 | return semantic.isAcceptableType(*item.unnamedType, intrinsics); |
241 | }), *item.semantic); |
242 | if (!acceptable) |
243 | return false; |
244 | } |
245 | return true; |
246 | }; |
247 | if (!checkSemanticTypes(inputItems)) |
248 | return false; |
249 | if (!checkSemanticTypes(outputItems)) |
250 | return false; |
251 | } |
252 | |
253 | { |
254 | auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool { |
255 | for (auto& item : items) { |
256 | auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool { |
257 | return semantic.isAcceptableForShaderItemDirection(direction, entryPointType); |
258 | }), *item.semantic); |
259 | if (!acceptable) |
260 | return false; |
261 | } |
262 | return true; |
263 | }; |
264 | if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input)) |
265 | return false; |
266 | if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output)) |
267 | return false; |
268 | } |
269 | |
270 | { |
271 | auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool { |
272 | for (auto& item : items) { |
273 | PODChecker podChecker; |
274 | if (is<AST::PointerType>(item.unnamedType)) |
275 | podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType()); |
276 | else if (is<AST::ArrayReferenceType>(item.unnamedType)) |
277 | podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType()); |
278 | else if (is<AST::ArrayType>(item.unnamedType)) |
279 | podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type()); |
280 | else |
281 | continue; |
282 | if (podChecker.error()) |
283 | return false; |
284 | } |
285 | return true; |
286 | }; |
287 | if (!checkPODData(inputItems)) |
288 | return false; |
289 | if (!checkPODData(outputItems)) |
290 | return false; |
291 | } |
292 | |
293 | return true; |
294 | } |
295 | |
296 | static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, const Intrinsics& intrinsics, NameContext& nameContext) |
297 | { |
298 | enum class CheckKind { |
299 | Index, |
300 | Dot |
301 | }; |
302 | |
303 | auto checkGetter = [&](CheckKind kind) -> bool { |
304 | size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1; |
305 | if (functionDefinition.parameters().size() != numExpectedParameters) |
306 | return false; |
307 | auto& firstParameterUnifyNode = (*functionDefinition.parameters()[0].type())->unifyNode(); |
308 | if (is<AST::UnnamedType>(firstParameterUnifyNode)) { |
309 | auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode); |
310 | if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType)) |
311 | return false; |
312 | } |
313 | if (kind == CheckKind::Index) { |
314 | auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1].type())->unifyNode(); |
315 | if (!is<AST::NamedType>(secondParameterUnifyNode)) |
316 | return false; |
317 | auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode); |
318 | if (!is<AST::NativeTypeDeclaration>(namedType)) |
319 | return false; |
320 | auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType); |
321 | if (!nativeTypeDeclaration.isInt()) |
322 | return false; |
323 | } |
324 | return true; |
325 | }; |
326 | |
327 | auto checkSetter = [&](CheckKind kind) -> bool { |
328 | size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2; |
329 | if (functionDefinition.parameters().size() != numExpectedParameters) |
330 | return false; |
331 | auto& firstArgumentUnifyNode = (*functionDefinition.parameters()[0].type())->unifyNode(); |
332 | if (is<AST::UnnamedType>(firstArgumentUnifyNode)) { |
333 | auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode); |
334 | if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType)) |
335 | return false; |
336 | } |
337 | if (kind == CheckKind::Index) { |
338 | auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1].type())->unifyNode(); |
339 | if (!is<AST::NamedType>(secondParameterUnifyNode)) |
340 | return false; |
341 | auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode); |
342 | if (!is<AST::NativeTypeDeclaration>(namedType)) |
343 | return false; |
344 | auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType); |
345 | if (!nativeTypeDeclaration.isInt()) |
346 | return false; |
347 | } |
348 | if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0].type())) |
349 | return false; |
350 | auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1].type(); |
351 | auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1); |
352 | auto* getterFuncs = nameContext.getFunctions(getterName); |
353 | if (!getterFuncs) |
354 | return false; |
355 | Vector<ResolvingType> argumentTypes; |
356 | Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences; |
357 | for (size_t i = 0; i < numExpectedParameters - 1; ++i) |
358 | argumentTypes.append((*functionDefinition.parameters()[0].type())->clone()); |
359 | for (auto& argumentType : argumentTypes) |
360 | argumentTypeReferences.append(argumentType); |
361 | auto* overload = resolveFunctionOverloadImpl(*getterFuncs, argumentTypeReferences, nullptr); |
362 | if (!overload) |
363 | return false; |
364 | auto& resultType = overload->type(); |
365 | return matches(resultType, valueType); |
366 | }; |
367 | |
368 | auto checkAnder = [&](CheckKind kind) -> bool { |
369 | size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1; |
370 | if (functionDefinition.parameters().size() != numExpectedParameters) |
371 | return false; |
372 | { |
373 | auto& unifyNode = functionDefinition.type().unifyNode(); |
374 | if (!is<AST::UnnamedType>(unifyNode)) |
375 | return false; |
376 | auto& unnamedType = downcast<AST::UnnamedType>(unifyNode); |
377 | if (!is<AST::PointerType>(unnamedType)) |
378 | return false; |
379 | } |
380 | { |
381 | auto& unifyNode = (*functionDefinition.parameters()[0].type())->unifyNode(); |
382 | if (!is<AST::UnnamedType>(unifyNode)) |
383 | return false; |
384 | auto& unnamedType = downcast<AST::UnnamedType>(unifyNode); |
385 | return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType); |
386 | } |
387 | }; |
388 | |
389 | if (!functionDefinition.isOperator()) |
390 | return true; |
391 | if (functionDefinition.isCast()) |
392 | return true; |
393 | if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--" ) { |
394 | return functionDefinition.parameters().size() == 1 |
395 | && matches(*functionDefinition.parameters()[0].type(), functionDefinition.type()); |
396 | } |
397 | if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-" ) |
398 | return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2; |
399 | if (functionDefinition.name() == "operator*" |
400 | || functionDefinition.name() == "operator/" |
401 | || functionDefinition.name() == "operator%" |
402 | || functionDefinition.name() == "operator&" |
403 | || functionDefinition.name() == "operator|" |
404 | || functionDefinition.name() == "operator^" |
405 | || functionDefinition.name() == "operator<<" |
406 | || functionDefinition.name() == "opreator>>" ) |
407 | return functionDefinition.parameters().size() == 2; |
408 | if (functionDefinition.name() == "operator~" ) |
409 | return functionDefinition.parameters().size() == 1; |
410 | if (functionDefinition.name() == "operator==" |
411 | || functionDefinition.name() == "operator<" |
412 | || functionDefinition.name() == "operator<=" |
413 | || functionDefinition.name() == "operator>" |
414 | || functionDefinition.name() == "operator>=" ) { |
415 | return functionDefinition.parameters().size() == 2 |
416 | && matches(functionDefinition.type(), intrinsics.boolType()); |
417 | } |
418 | if (functionDefinition.name() == "operator[]" ) |
419 | return checkGetter(CheckKind::Index); |
420 | if (functionDefinition.name() == "operator[]=" ) |
421 | return checkSetter(CheckKind::Index); |
422 | if (functionDefinition.name() == "operator&[]" ) |
423 | return checkAnder(CheckKind::Index); |
424 | if (functionDefinition.name().startsWith("operator." )) { |
425 | if (functionDefinition.name().endsWith("=" )) |
426 | return checkSetter(CheckKind::Dot); |
427 | return checkGetter(CheckKind::Dot); |
428 | } |
429 | if (functionDefinition.name().startsWith("operator&." )) |
430 | return checkAnder(CheckKind::Dot); |
431 | return false; |
432 | } |
433 | |
434 | class Checker : public Visitor { |
435 | public: |
436 | Checker(const Intrinsics& intrinsics, Program& program) |
437 | : m_intrinsics(intrinsics) |
438 | , m_program(program) |
439 | { |
440 | } |
441 | |
442 | ~Checker() = default; |
443 | |
444 | void visit(Program&) override; |
445 | |
446 | bool assignTypes(); |
447 | |
448 | private: |
449 | bool checkShaderType(const AST::FunctionDefinition&); |
450 | void finishVisitingPropertyAccess(AST::PropertyAccessExpression&, AST::UnnamedType& wrappedBaseType, AST::UnnamedType* extraArgumentType = nullptr); |
451 | bool isBoolType(ResolvingType&); |
452 | struct RecurseInfo { |
453 | ResolvingType& resolvingType; |
454 | Optional<AST::AddressSpace>& addressSpace; |
455 | }; |
456 | Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLValue = false); |
457 | Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLValue = false); |
458 | Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&); |
459 | bool recurseAndRequireBoolType(AST::Expression&); |
460 | void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, Optional<AST::AddressSpace> = WTF::nullopt); |
461 | void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, Optional<AST::AddressSpace> = WTF::nullopt); |
462 | void forwardType(AST::Expression&, ResolvingType&, Optional<AST::AddressSpace> = WTF::nullopt); |
463 | |
464 | void visit(AST::FunctionDefinition&) override; |
465 | void visit(AST::EnumerationDefinition&) override; |
466 | void visit(AST::TypeReference&) override; |
467 | void visit(AST::VariableDeclaration&) override; |
468 | void visit(AST::AssignmentExpression&) override; |
469 | void visit(AST::ReadModifyWriteExpression&) override; |
470 | void visit(AST::DereferenceExpression&) override; |
471 | void visit(AST::MakePointerExpression&) override; |
472 | void visit(AST::MakeArrayReferenceExpression&) override; |
473 | void visit(AST::DotExpression&) override; |
474 | void visit(AST::IndexExpression&) override; |
475 | void visit(AST::VariableReference&) override; |
476 | void visit(AST::Return&) override; |
477 | void visit(AST::PointerType&) override; |
478 | void visit(AST::ArrayReferenceType&) override; |
479 | void visit(AST::IntegerLiteral&) override; |
480 | void visit(AST::UnsignedIntegerLiteral&) override; |
481 | void visit(AST::FloatLiteral&) override; |
482 | void visit(AST::NullLiteral&) override; |
483 | void visit(AST::BooleanLiteral&) override; |
484 | void visit(AST::EnumerationMemberLiteral&) override; |
485 | void visit(AST::LogicalNotExpression&) override; |
486 | void visit(AST::LogicalExpression&) override; |
487 | void visit(AST::IfStatement&) override; |
488 | void visit(AST::WhileLoop&) override; |
489 | void visit(AST::DoWhileLoop&) override; |
490 | void visit(AST::ForLoop&) override; |
491 | void visit(AST::SwitchStatement&) override; |
492 | void visit(AST::CommaExpression&) override; |
493 | void visit(AST::TernaryExpression&) override; |
494 | void visit(AST::CallExpression&) override; |
495 | |
496 | HashMap<AST::Expression*, ResolvingType> m_typeMap; |
497 | HashMap<AST::Expression*, Optional<AST::AddressSpace>> m_addressSpaceMap; |
498 | HashSet<String> m_vertexEntryPoints; |
499 | HashSet<String> m_fragmentEntryPoints; |
500 | HashSet<String> m_computeEntryPoints; |
501 | const Intrinsics& m_intrinsics; |
502 | Program& m_program; |
503 | }; |
504 | |
505 | void Checker::visit(Program& program) |
506 | { |
507 | // These visiting functions might add new global statements, so don't use foreach syntax. |
508 | for (size_t i = 0; i < program.typeDefinitions().size(); ++i) |
509 | checkErrorAndVisit(program.typeDefinitions()[i]); |
510 | for (size_t i = 0; i < program.structureDefinitions().size(); ++i) |
511 | checkErrorAndVisit(program.structureDefinitions()[i]); |
512 | for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i) |
513 | checkErrorAndVisit(program.enumerationDefinitions()[i]); |
514 | for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i) |
515 | checkErrorAndVisit(program.nativeTypeDeclarations()[i]); |
516 | |
517 | for (size_t i = 0; i < program.functionDefinitions().size(); ++i) |
518 | checkErrorAndVisit(program.functionDefinitions()[i]); |
519 | for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i) |
520 | checkErrorAndVisit(program.nativeFunctionDeclarations()[i]); |
521 | } |
522 | |
523 | bool Checker::assignTypes() |
524 | { |
525 | for (auto& keyValuePair : m_typeMap) { |
526 | auto success = keyValuePair.value.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool { |
527 | keyValuePair.key->setType(unnamedType->clone()); |
528 | return true; |
529 | }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool { |
530 | if (!resolvableTypeReference->resolvableType().resolvedType()) { |
531 | if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType()))) |
532 | return false; |
533 | } |
534 | keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType()->clone()); |
535 | return true; |
536 | })); |
537 | if (!success) |
538 | return false; |
539 | } |
540 | |
541 | for (auto& keyValuePair : m_addressSpaceMap) |
542 | keyValuePair.key->setAddressSpace(keyValuePair.value); |
543 | return true; |
544 | } |
545 | |
546 | bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition) |
547 | { |
548 | switch (*functionDefinition.entryPointType()) { |
549 | case AST::EntryPointType::Vertex: |
550 | return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name())); |
551 | case AST::EntryPointType::Fragment: |
552 | return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name())); |
553 | case AST::EntryPointType::Compute: |
554 | return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name())); |
555 | } |
556 | } |
557 | |
558 | void Checker::visit(AST::FunctionDefinition& functionDefinition) |
559 | { |
560 | if (functionDefinition.entryPointType()) { |
561 | if (!checkShaderType(functionDefinition)) { |
562 | setError(); |
563 | return; |
564 | } |
565 | auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition); |
566 | if (!entryPointItems) { |
567 | setError(); |
568 | return; |
569 | } |
570 | if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) { |
571 | setError(); |
572 | return; |
573 | } |
574 | } |
575 | if (!checkOperatorOverload(functionDefinition, m_intrinsics, m_program.nameContext())) { |
576 | setError(); |
577 | return; |
578 | } |
579 | |
580 | Visitor::visit(functionDefinition); |
581 | } |
582 | |
583 | static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right) |
584 | { |
585 | return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> { |
586 | return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> { |
587 | if (matches(left, right)) |
588 | return left->clone(); |
589 | return WTF::nullopt; |
590 | }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> { |
591 | return matchAndCommit(left, right->resolvableType()); |
592 | })); |
593 | }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> { |
594 | return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> { |
595 | return matchAndCommit(right, left->resolvableType()); |
596 | }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> { |
597 | return matchAndCommit(left->resolvableType(), right->resolvableType()); |
598 | })); |
599 | })); |
600 | } |
601 | |
602 | static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType) |
603 | { |
604 | return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> { |
605 | if (matches(unnamedType, resolvingType)) |
606 | return unnamedType.clone(); |
607 | return WTF::nullopt; |
608 | }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> { |
609 | return matchAndCommit(unnamedType, resolvingType->resolvableType()); |
610 | })); |
611 | } |
612 | |
613 | static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType) |
614 | { |
615 | return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> { |
616 | if (matches(resolvingType, namedType)) |
617 | return resolvingType->clone(); |
618 | return WTF::nullopt; |
619 | }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> { |
620 | return matchAndCommit(namedType, resolvingType->resolvableType()); |
621 | })); |
622 | } |
623 | |
624 | void Checker::visit(AST::EnumerationDefinition& enumerationDefinition) |
625 | { |
626 | auto* baseType = ([&]() -> AST::NativeTypeDeclaration* { |
627 | checkErrorAndVisit(enumerationDefinition.type()); |
628 | auto& baseType = enumerationDefinition.type().unifyNode(); |
629 | if (!is<AST::NamedType>(baseType)) |
630 | return nullptr; |
631 | auto& namedType = downcast<AST::NamedType>(baseType); |
632 | if (!is<AST::NativeTypeDeclaration>(namedType)) |
633 | return nullptr; |
634 | auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType); |
635 | if (!nativeTypeDeclaration.isInt()) |
636 | return nullptr; |
637 | return &nativeTypeDeclaration; |
638 | })(); |
639 | if (!baseType) { |
640 | setError(); |
641 | return; |
642 | } |
643 | |
644 | auto enumerationMembers = enumerationDefinition.enumerationMembers(); |
645 | |
646 | for (auto& member : enumerationMembers) { |
647 | if (!member.get().value()) |
648 | continue; |
649 | |
650 | bool success = false; |
651 | member.get().value()->visit(WTF::makeVisitor([&](AST::Expression& value) { |
652 | auto valueInfo = recurseAndGetInfo(value); |
653 | if (!valueInfo) |
654 | return; |
655 | success = static_cast<bool>(matchAndCommit(valueInfo->resolvingType, *baseType)); |
656 | })); |
657 | if (!success) { |
658 | setError(); |
659 | return; |
660 | } |
661 | } |
662 | |
663 | int64_t nextValue = 0; |
664 | for (auto& member : enumerationMembers) { |
665 | if (member.get().value()) { |
666 | int64_t value; |
667 | member.get().value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) { |
668 | value = integerLiteral.valueForSelectedType(); |
669 | }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) { |
670 | value = unsignedIntegerLiteral.valueForSelectedType(); |
671 | }, [&](auto&) { |
672 | ASSERT_NOT_REACHED(); |
673 | })); |
674 | nextValue = baseType->successor()(value); |
675 | } else { |
676 | if (nextValue > std::numeric_limits<int>::max()) { |
677 | ASSERT(nextValue <= std::numeric_limits<unsigned>::max()); |
678 | member.get().setValue(AST::ConstantExpression(AST::UnsignedIntegerLiteral(Lexer::Token(member.get().origin()), static_cast<unsigned>(nextValue)))); |
679 | } |
680 | ASSERT(nextValue >= std::numeric_limits<int>::min()); |
681 | member.get().setValue(AST::ConstantExpression(AST::IntegerLiteral(Lexer::Token(member.get().origin()), static_cast<int>(nextValue)))); |
682 | nextValue = baseType->successor()(nextValue); |
683 | } |
684 | } |
685 | |
686 | auto getValue = [&](AST::EnumerationMember& member) -> int64_t { |
687 | int64_t value; |
688 | ASSERT(member.value()); |
689 | member.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) { |
690 | value = integerLiteral.value(); |
691 | }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) { |
692 | value = unsignedIntegerLiteral.value(); |
693 | }, [&](auto&) { |
694 | ASSERT_NOT_REACHED(); |
695 | })); |
696 | return value; |
697 | }; |
698 | |
699 | for (size_t i = 0; i < enumerationMembers.size(); ++i) { |
700 | auto value = getValue(enumerationMembers[i].get()); |
701 | for (size_t j = i + 1; j < enumerationMembers.size(); ++j) { |
702 | auto otherValue = getValue(enumerationMembers[j].get()); |
703 | if (value == otherValue) { |
704 | setError(); |
705 | return; |
706 | } |
707 | } |
708 | } |
709 | |
710 | bool foundZero = false; |
711 | for (auto& member : enumerationMembers) { |
712 | if (!getValue(member.get())) { |
713 | foundZero = true; |
714 | break; |
715 | } |
716 | } |
717 | if (!foundZero) { |
718 | setError(); |
719 | return; |
720 | } |
721 | } |
722 | |
723 | void Checker::visit(AST::TypeReference& typeReference) |
724 | { |
725 | ASSERT(typeReference.resolvedType()); |
726 | |
727 | for (auto& typeArgument : typeReference.typeArguments()) |
728 | checkErrorAndVisit(typeArgument); |
729 | } |
730 | |
731 | auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLValue) -> Optional<RecurseInfo> |
732 | { |
733 | Visitor::visit(expression); |
734 | if (error()) |
735 | return WTF::nullopt; |
736 | return getInfo(expression, requiresLValue); |
737 | } |
738 | |
739 | auto Checker::getInfo(AST::Expression& expression, bool requiresLValue) -> Optional<RecurseInfo> |
740 | { |
741 | auto typeIterator = m_typeMap.find(&expression); |
742 | ASSERT(typeIterator != m_typeMap.end()); |
743 | |
744 | auto addressSpaceIterator = m_addressSpaceMap.find(&expression); |
745 | ASSERT(addressSpaceIterator != m_addressSpaceMap.end()); |
746 | if (requiresLValue && !addressSpaceIterator->value) { |
747 | setError(); |
748 | return WTF::nullopt; |
749 | } |
750 | return {{ typeIterator->value, addressSpaceIterator->value }}; |
751 | } |
752 | |
753 | void Checker::visit(AST::VariableDeclaration& variableDeclaration) |
754 | { |
755 | // ReadModifyWriteExpressions are the only place where anonymous variables exist, |
756 | // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type. |
757 | checkErrorAndVisit(*variableDeclaration.type()); |
758 | if (variableDeclaration.initializer()) { |
759 | auto& lhsType = *variableDeclaration.type(); |
760 | auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer()); |
761 | if (!initializerInfo) |
762 | return; |
763 | if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) { |
764 | setError(); |
765 | return; |
766 | } |
767 | } |
768 | } |
769 | |
770 | void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType>&& unnamedType, Optional<AST::AddressSpace> addressSpace) |
771 | { |
772 | auto addResult = m_typeMap.add(&expression, WTFMove(unnamedType)); |
773 | ASSERT_UNUSED(addResult, addResult.isNewEntry); |
774 | auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace); |
775 | ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry); |
776 | } |
777 | |
778 | void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, Optional<AST::AddressSpace> addressSpace) |
779 | { |
780 | auto addResult = m_typeMap.add(&expression, WTFMove(resolvableTypeReference)); |
781 | ASSERT_UNUSED(addResult, addResult.isNewEntry); |
782 | auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace); |
783 | ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry); |
784 | } |
785 | |
786 | void Checker::visit(AST::AssignmentExpression& assignmentExpression) |
787 | { |
788 | auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true); |
789 | if (!leftInfo) |
790 | return; |
791 | |
792 | auto rightInfo = recurseAndGetInfo(assignmentExpression.right()); |
793 | if (!rightInfo) |
794 | return; |
795 | |
796 | auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType); |
797 | if (!resultType) { |
798 | setError(); |
799 | return; |
800 | } |
801 | |
802 | assignType(assignmentExpression, WTFMove(*resultType)); |
803 | } |
804 | |
805 | void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, Optional<AST::AddressSpace> addressSpace) |
806 | { |
807 | resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) { |
808 | auto addResult = m_typeMap.add(&expression, result->clone()); |
809 | ASSERT_UNUSED(addResult, addResult.isNewEntry); |
810 | }, [&](RefPtr<ResolvableTypeReference>& result) { |
811 | auto addResult = m_typeMap.add(&expression, result.copyRef()); |
812 | ASSERT_UNUSED(addResult, addResult.isNewEntry); |
813 | })); |
814 | auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace); |
815 | ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry); |
816 | } |
817 | |
818 | void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression) |
819 | { |
820 | auto lValueInfo = recurseAndGetInfo(readModifyWriteExpression.lValue(), true); |
821 | if (!lValueInfo) |
822 | return; |
823 | |
824 | // FIXME: Figure out what to do with the ReadModifyWriteExpression's AnonymousVariables. |
825 | |
826 | auto newValueInfo = recurseAndGetInfo(*readModifyWriteExpression.newValueExpression()); |
827 | if (!newValueInfo) |
828 | return; |
829 | |
830 | if (!matchAndCommit(lValueInfo->resolvingType, newValueInfo->resolvingType)) { |
831 | setError(); |
832 | return; |
833 | } |
834 | |
835 | auto resultInfo = recurseAndGetInfo(*readModifyWriteExpression.resultExpression()); |
836 | if (!resultInfo) |
837 | return; |
838 | |
839 | forwardType(readModifyWriteExpression, resultInfo->resolvingType); |
840 | } |
841 | |
842 | static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType) |
843 | { |
844 | return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* { |
845 | return &type; |
846 | }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* { |
847 | // FIXME: If the type isn't committed, should we just commit() it now? |
848 | return type->resolvableType().resolvedType(); |
849 | })); |
850 | } |
851 | |
852 | void Checker::visit(AST::DereferenceExpression& dereferenceExpression) |
853 | { |
854 | auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer()); |
855 | if (!pointerInfo) |
856 | return; |
857 | |
858 | auto* unnamedType = getUnnamedType(pointerInfo->resolvingType); |
859 | |
860 | auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* { |
861 | if (!unnamedType) |
862 | return nullptr; |
863 | auto& unifyNode = unnamedType->unifyNode(); |
864 | if (!is<AST::UnnamedType>(unifyNode)) |
865 | return nullptr; |
866 | auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode); |
867 | if (!is<AST::PointerType>(unnamedUnifyType)) |
868 | return nullptr; |
869 | return &downcast<AST::PointerType>(unnamedUnifyType); |
870 | })(unnamedType); |
871 | if (!pointerType) { |
872 | setError(); |
873 | return; |
874 | } |
875 | |
876 | assignType(dereferenceExpression, pointerType->clone(), pointerType->addressSpace()); |
877 | } |
878 | |
879 | void Checker::visit(AST::MakePointerExpression& makePointerExpression) |
880 | { |
881 | auto lValueInfo = recurseAndGetInfo(makePointerExpression.lValue(), true); |
882 | if (!lValueInfo) |
883 | return; |
884 | |
885 | auto* lValueType = getUnnamedType(lValueInfo->resolvingType); |
886 | if (!lValueType) { |
887 | setError(); |
888 | return; |
889 | } |
890 | |
891 | assignType(makePointerExpression, makeUniqueRef<AST::PointerType>(Lexer::Token(makePointerExpression.origin()), *lValueInfo->addressSpace, lValueType->clone())); |
892 | } |
893 | |
894 | void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression) |
895 | { |
896 | auto lValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.lValue()); |
897 | if (!lValueInfo) |
898 | return; |
899 | |
900 | auto* lValueType = getUnnamedType(lValueInfo->resolvingType); |
901 | if (!lValueType) { |
902 | setError(); |
903 | return; |
904 | } |
905 | |
906 | auto& unifyNode = lValueType->unifyNode(); |
907 | if (is<AST::UnnamedType>(unifyNode)) { |
908 | auto& unnamedType = downcast<AST::UnnamedType>(unifyNode); |
909 | if (is<AST::PointerType>(unnamedType)) { |
910 | auto& pointerType = downcast<AST::PointerType>(unnamedType); |
911 | // FIXME: Save the fact that we're not targetting the item; we're targetting the item's inner element. |
912 | assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), pointerType.addressSpace(), pointerType.elementType().clone())); |
913 | return; |
914 | } |
915 | |
916 | if (!lValueInfo->addressSpace) { |
917 | setError(); |
918 | return; |
919 | } |
920 | |
921 | if (is<AST::ArrayType>(unnamedType)) { |
922 | auto& arrayType = downcast<AST::ArrayType>(unnamedType); |
923 | // FIXME: Save the number of elements. |
924 | assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *lValueInfo->addressSpace, arrayType.type().clone())); |
925 | return; |
926 | } |
927 | } |
928 | |
929 | if (!lValueInfo->addressSpace) { |
930 | setError(); |
931 | return; |
932 | } |
933 | |
934 | assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *lValueInfo->addressSpace, lValueType->clone())); |
935 | } |
936 | |
937 | void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& propertyAccessExpression, AST::UnnamedType& wrappedBaseType, AST::UnnamedType* extraArgumentType) |
938 | { |
939 | using OverloadResolution = std::tuple<AST::FunctionDeclaration*, AST::UnnamedType*>; |
940 | |
941 | AST::FunctionDeclaration* getFunction; |
942 | AST::UnnamedType* getReturnType; |
943 | std::tie(getFunction, getReturnType) = ([&]() -> OverloadResolution { |
944 | ResolvingType getArgumentType1(wrappedBaseType.clone()); |
945 | Optional<ResolvingType> getArgumentType2; |
946 | if (extraArgumentType) |
947 | getArgumentType2 = ResolvingType(extraArgumentType->clone()); |
948 | |
949 | Vector<std::reference_wrapper<ResolvingType>> getArgumentTypes; |
950 | getArgumentTypes.append(getArgumentType1); |
951 | if (getArgumentType2) |
952 | getArgumentTypes.append(*getArgumentType2); |
953 | |
954 | auto* getFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleGetOverloads(), getArgumentTypes, nullptr); |
955 | if (!getFunction) |
956 | return std::make_pair(nullptr, nullptr); |
957 | return std::make_pair(getFunction, &getFunction->type()); |
958 | })(); |
959 | |
960 | AST::FunctionDeclaration* andFunction; |
961 | AST::UnnamedType* andReturnType; |
962 | std::tie(andFunction, andReturnType) = ([&]() -> OverloadResolution { |
963 | auto computeAndArgumentType = [&](AST::UnnamedType& unnamedType) -> Optional<ResolvingType> { |
964 | if (is<AST::ArrayReferenceType>(unnamedType)) |
965 | return { unnamedType.clone() }; |
966 | if (is<AST::ArrayType>(unnamedType)) |
967 | return { ResolvingType(makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::ArrayType>(unnamedType).type().clone())) }; |
968 | if (is<AST::PointerType>(unnamedType)) |
969 | return WTF::nullopt; |
970 | return { ResolvingType(makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::TypeReference>(unnamedType).clone())) }; |
971 | }; |
972 | auto computeAndReturnType = [&](AST::UnnamedType& unnamedType) -> AST::UnnamedType* { |
973 | if (is<AST::PointerType>(unnamedType)) |
974 | return &downcast<AST::PointerType>(unnamedType).elementType(); |
975 | return nullptr; |
976 | }; |
977 | |
978 | auto andArgumentType1 = computeAndArgumentType(wrappedBaseType); |
979 | if (!andArgumentType1) |
980 | return std::make_pair(nullptr, nullptr); |
981 | Optional<ResolvingType> andArgumentType2; |
982 | if (extraArgumentType) |
983 | andArgumentType2 = ResolvingType(extraArgumentType->clone()); |
984 | |
985 | Vector<std::reference_wrapper<ResolvingType>> andArgumentTypes; |
986 | andArgumentTypes.append(*andArgumentType1); |
987 | if (andArgumentType2) |
988 | andArgumentTypes.append(*andArgumentType2); |
989 | |
990 | auto* andFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleAndOverloads(), andArgumentTypes, nullptr); |
991 | if (!andFunction) |
992 | return std::make_pair(nullptr, nullptr); |
993 | return std::make_pair(andFunction, computeAndReturnType(andFunction->type())); |
994 | })(); |
995 | |
996 | if (!getReturnType && !andReturnType) { |
997 | setError(); |
998 | return; |
999 | } |
1000 | |
1001 | if (getReturnType && andReturnType && !matches(*getReturnType, *andReturnType)) { |
1002 | setError(); |
1003 | return; |
1004 | } |
1005 | |
1006 | AST::FunctionDeclaration* setFunction; |
1007 | AST::UnnamedType* setReturnType; |
1008 | std::tie(setFunction, setReturnType) = ([&]() -> OverloadResolution { |
1009 | ResolvingType setArgument1Type(wrappedBaseType.clone()); |
1010 | Optional<ResolvingType> setArgumentType2; |
1011 | if (extraArgumentType) |
1012 | setArgumentType2 = ResolvingType(extraArgumentType->clone()); |
1013 | ResolvingType setArgument3Type(getReturnType ? getReturnType->clone() : andReturnType->clone()); |
1014 | |
1015 | Vector<std::reference_wrapper<ResolvingType>> setArgumentTypes; |
1016 | setArgumentTypes.append(setArgument1Type); |
1017 | if (setArgumentType2) |
1018 | setArgumentTypes.append(*setArgumentType2); |
1019 | setArgumentTypes.append(setArgument3Type); |
1020 | |
1021 | auto* setFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleSetOverloads(), setArgumentTypes, nullptr); |
1022 | if (!setFunction) |
1023 | return std::make_pair(nullptr, nullptr); |
1024 | return std::make_pair(setFunction, &setFunction->type()); |
1025 | })(); |
1026 | |
1027 | if (setFunction) { |
1028 | if (!matches(setFunction->type(), wrappedBaseType)) { |
1029 | setError(); |
1030 | return; |
1031 | } |
1032 | } |
1033 | |
1034 | Optional<AST::AddressSpace> addressSpace; |
1035 | if (getReturnType || andReturnType) { |
1036 | // FIXME: The reference compiler has "else if (!node.base.isLValue && !baseType.isArrayRef)", |
1037 | // but I don't understand why it exists. I haven't written it here, and I'll investigate |
1038 | // if we can remove it from the reference compiler. |
1039 | if (is<AST::ReferenceType>(wrappedBaseType)) |
1040 | addressSpace = downcast<AST::ReferenceType>(wrappedBaseType).addressSpace(); |
1041 | else { |
1042 | auto addressSpaceIterator = m_addressSpaceMap.find(&propertyAccessExpression.base()); |
1043 | ASSERT(addressSpaceIterator != m_addressSpaceMap.end()); |
1044 | if (addressSpaceIterator->value) |
1045 | addressSpace = *addressSpaceIterator->value; |
1046 | else { |
1047 | setError(); |
1048 | return; |
1049 | } |
1050 | } |
1051 | } |
1052 | |
1053 | // FIXME: Generate the call expressions |
1054 | |
1055 | assignType(propertyAccessExpression, getReturnType ? getReturnType->clone() : andReturnType->clone(), addressSpace); |
1056 | } |
1057 | |
1058 | Optional<UniqueRef<AST::UnnamedType>> Checker::recurseAndWrapBaseType(AST::PropertyAccessExpression& propertyAccessExpression) |
1059 | { |
1060 | auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base()); |
1061 | if (!baseInfo) |
1062 | return WTF::nullopt; |
1063 | |
1064 | auto* baseType = getUnnamedType(baseInfo->resolvingType); |
1065 | if (!baseType) { |
1066 | setError(); |
1067 | return WTF::nullopt; |
1068 | } |
1069 | auto& baseUnifyNode = baseType->unifyNode(); |
1070 | if (is<AST::UnnamedType>(baseUnifyNode)) |
1071 | return downcast<AST::UnnamedType>(baseUnifyNode).clone(); |
1072 | ASSERT(is<AST::NamedType>(baseUnifyNode)); |
1073 | return { AST::TypeReference::wrap(Lexer::Token(propertyAccessExpression.origin()), downcast<AST::NamedType>(baseUnifyNode)) }; |
1074 | } |
1075 | |
1076 | void Checker::visit(AST::DotExpression& dotExpression) |
1077 | { |
1078 | auto baseType = recurseAndWrapBaseType(dotExpression); |
1079 | if (!baseType) |
1080 | return; |
1081 | |
1082 | finishVisitingPropertyAccess(dotExpression, *baseType); |
1083 | } |
1084 | |
1085 | void Checker::visit(AST::IndexExpression& indexExpression) |
1086 | { |
1087 | auto baseType = recurseAndWrapBaseType(indexExpression); |
1088 | if (!baseType) |
1089 | return; |
1090 | |
1091 | auto indexInfo = recurseAndGetInfo(indexExpression.indexExpression()); |
1092 | if (!indexInfo) |
1093 | return; |
1094 | auto indexExpressionType = getUnnamedType(indexInfo->resolvingType); |
1095 | if (!indexExpressionType) { |
1096 | setError(); |
1097 | return; |
1098 | } |
1099 | |
1100 | finishVisitingPropertyAccess(indexExpression, WTFMove(*baseType), indexExpressionType); |
1101 | } |
1102 | |
1103 | void Checker::visit(AST::VariableReference& variableReference) |
1104 | { |
1105 | ASSERT(variableReference.variable()); |
1106 | ASSERT(variableReference.variable()->type()); |
1107 | |
1108 | Optional<AST::AddressSpace> addressSpace; |
1109 | if (!variableReference.variable()->isAnonymous()) |
1110 | addressSpace = AST::AddressSpace::Thread; |
1111 | assignType(variableReference, variableReference.variable()->type()->clone(), addressSpace); |
1112 | } |
1113 | |
1114 | void Checker::visit(AST::Return& returnStatement) |
1115 | { |
1116 | ASSERT(returnStatement.function()); |
1117 | if (returnStatement.value()) { |
1118 | auto valueInfo = recurseAndGetInfo(*returnStatement.value()); |
1119 | if (!valueInfo) |
1120 | return; |
1121 | if (!matchAndCommit(valueInfo->resolvingType, returnStatement.function()->type())) |
1122 | setError(); |
1123 | return; |
1124 | } |
1125 | |
1126 | if (!matches(returnStatement.function()->type(), m_intrinsics.voidType())) |
1127 | setError(); |
1128 | } |
1129 | |
1130 | void Checker::visit(AST::PointerType&) |
1131 | { |
1132 | // Following pointer types can cause infinite loops because of data structures |
1133 | // like linked lists. |
1134 | // FIXME: Make sure this function should be empty |
1135 | } |
1136 | |
1137 | void Checker::visit(AST::ArrayReferenceType&) |
1138 | { |
1139 | // Following array reference types can cause infinite loops because of data |
1140 | // structures like linked lists. |
1141 | // FIXME: Make sure this function should be empty |
1142 | } |
1143 | |
1144 | void Checker::visit(AST::IntegerLiteral& integerLiteral) |
1145 | { |
1146 | assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type()))); |
1147 | } |
1148 | |
1149 | void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) |
1150 | { |
1151 | assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type()))); |
1152 | } |
1153 | |
1154 | void Checker::visit(AST::FloatLiteral& floatLiteral) |
1155 | { |
1156 | assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type()))); |
1157 | } |
1158 | |
1159 | void Checker::visit(AST::NullLiteral& nullLiteral) |
1160 | { |
1161 | assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type()))); |
1162 | } |
1163 | |
1164 | void Checker::visit(AST::BooleanLiteral& booleanLiteral) |
1165 | { |
1166 | assignType(booleanLiteral, AST::TypeReference::wrap(Lexer::Token(booleanLiteral.origin()), m_intrinsics.boolType())); |
1167 | } |
1168 | |
1169 | void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral) |
1170 | { |
1171 | ASSERT(enumerationMemberLiteral.enumerationDefinition()); |
1172 | auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition(); |
1173 | assignType(enumerationMemberLiteral, AST::TypeReference::wrap(Lexer::Token(enumerationMemberLiteral.origin()), enumerationDefinition)); |
1174 | } |
1175 | |
1176 | bool Checker::isBoolType(ResolvingType& resolvingType) |
1177 | { |
1178 | return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool { |
1179 | return matches(left, m_intrinsics.boolType()); |
1180 | }, [&](RefPtr<ResolvableTypeReference>& left) -> bool { |
1181 | return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType())); |
1182 | })); |
1183 | } |
1184 | |
1185 | bool Checker::recurseAndRequireBoolType(AST::Expression& expression) |
1186 | { |
1187 | auto expressionInfo = recurseAndGetInfo(expression); |
1188 | if (!expressionInfo) |
1189 | return false; |
1190 | if (!isBoolType(expressionInfo->resolvingType)) { |
1191 | setError(); |
1192 | return false; |
1193 | } |
1194 | return true; |
1195 | } |
1196 | |
1197 | void Checker::visit(AST::LogicalNotExpression& logicalNotExpression) |
1198 | { |
1199 | if (!recurseAndRequireBoolType(logicalNotExpression.operand())) |
1200 | return; |
1201 | assignType(logicalNotExpression, AST::TypeReference::wrap(Lexer::Token(logicalNotExpression.origin()), m_intrinsics.boolType())); |
1202 | } |
1203 | |
1204 | void Checker::visit(AST::LogicalExpression& logicalExpression) |
1205 | { |
1206 | if (!recurseAndRequireBoolType(logicalExpression.left())) |
1207 | return; |
1208 | if (!recurseAndRequireBoolType(logicalExpression.right())) |
1209 | return; |
1210 | assignType(logicalExpression, AST::TypeReference::wrap(Lexer::Token(logicalExpression.origin()), m_intrinsics.boolType())); |
1211 | } |
1212 | |
1213 | void Checker::visit(AST::IfStatement& ifStatement) |
1214 | { |
1215 | if (!recurseAndRequireBoolType(ifStatement.conditional())) |
1216 | return; |
1217 | checkErrorAndVisit(ifStatement.body()); |
1218 | if (ifStatement.elseBody()) |
1219 | checkErrorAndVisit(*ifStatement.elseBody()); |
1220 | } |
1221 | |
1222 | void Checker::visit(AST::WhileLoop& whileLoop) |
1223 | { |
1224 | if (!recurseAndRequireBoolType(whileLoop.conditional())) |
1225 | return; |
1226 | checkErrorAndVisit(whileLoop.body()); |
1227 | } |
1228 | |
1229 | void Checker::visit(AST::DoWhileLoop& doWhileLoop) |
1230 | { |
1231 | checkErrorAndVisit(doWhileLoop.body()); |
1232 | recurseAndRequireBoolType(doWhileLoop.conditional()); |
1233 | } |
1234 | |
1235 | void Checker::visit(AST::ForLoop& forLoop) |
1236 | { |
1237 | WTF::visit(WTF::makeVisitor([&](AST::VariableDeclarationsStatement& variableDeclarationsStatement) { |
1238 | checkErrorAndVisit(variableDeclarationsStatement); |
1239 | }, [&](UniqueRef<AST::Expression>& expression) { |
1240 | checkErrorAndVisit(expression); |
1241 | }), forLoop.initialization()); |
1242 | if (error()) |
1243 | return; |
1244 | if (forLoop.condition()) { |
1245 | if (!recurseAndRequireBoolType(*forLoop.condition())) |
1246 | return; |
1247 | } |
1248 | if (forLoop.increment()) |
1249 | checkErrorAndVisit(*forLoop.increment()); |
1250 | checkErrorAndVisit(forLoop.body()); |
1251 | } |
1252 | |
1253 | void Checker::visit(AST::SwitchStatement& switchStatement) |
1254 | { |
1255 | auto* valueType = ([&]() -> AST::NamedType* { |
1256 | auto valueInfo = recurseAndGetInfo(switchStatement.value()); |
1257 | if (!valueInfo) |
1258 | return nullptr; |
1259 | auto* valueType = getUnnamedType(valueInfo->resolvingType); |
1260 | if (!valueType) |
1261 | return nullptr; |
1262 | auto& valueUnifyNode = valueType->unifyNode(); |
1263 | if (!is<AST::NamedType>(valueUnifyNode)) |
1264 | return nullptr; |
1265 | auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode); |
1266 | if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt()) |
1267 | && !is<AST::EnumerationDefinition>(valueNamedUnifyNode)) |
1268 | return nullptr; |
1269 | return &valueNamedUnifyNode; |
1270 | })(); |
1271 | if (!valueType) { |
1272 | setError(); |
1273 | return; |
1274 | } |
1275 | |
1276 | bool hasDefault = false; |
1277 | for (auto& switchCase : switchStatement.switchCases()) { |
1278 | checkErrorAndVisit(switchCase.block()); |
1279 | if (!switchCase.value()) { |
1280 | hasDefault = true; |
1281 | continue; |
1282 | } |
1283 | bool success; |
1284 | switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) { |
1285 | success = static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type())); |
1286 | }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) { |
1287 | success = static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type())); |
1288 | }, [&](AST::FloatLiteral& floatLiteral) { |
1289 | success = static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type())); |
1290 | }, [&](AST::NullLiteral& nullLiteral) { |
1291 | success = static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type())); |
1292 | }, [&](AST::BooleanLiteral&) { |
1293 | success = matches(*valueType, m_intrinsics.boolType()); |
1294 | }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) { |
1295 | ASSERT(enumerationMemberLiteral.enumerationDefinition()); |
1296 | success = matches(*valueType, *enumerationMemberLiteral.enumerationDefinition()); |
1297 | })); |
1298 | if (!success) { |
1299 | setError(); |
1300 | return; |
1301 | } |
1302 | } |
1303 | |
1304 | for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) { |
1305 | auto& firstCase = switchStatement.switchCases()[i]; |
1306 | for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) { |
1307 | auto& secondCase = switchStatement.switchCases()[j]; |
1308 | |
1309 | if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value())) |
1310 | continue; |
1311 | |
1312 | if (!static_cast<bool>(firstCase.value())) { |
1313 | setError(); |
1314 | return; |
1315 | } |
1316 | |
1317 | bool success = true; |
1318 | firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) { |
1319 | secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) { |
1320 | success = firstIntegerLiteral.value() != secondIntegerLiteral.value(); |
1321 | }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) { |
1322 | success = static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value()); |
1323 | }, [](auto&) { |
1324 | })); |
1325 | }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) { |
1326 | secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) { |
1327 | success = static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value()); |
1328 | }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) { |
1329 | success = firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value(); |
1330 | }, [](auto&) { |
1331 | })); |
1332 | }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) { |
1333 | secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral&) { |
1334 | }, [&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) { |
1335 | ASSERT(firstEnumerationMemberLiteral.enumerationMember()); |
1336 | ASSERT(secondEnumerationMemberLiteral.enumerationMember()); |
1337 | success = firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember(); |
1338 | }, [](auto&) { |
1339 | })); |
1340 | }, [](auto&) { |
1341 | })); |
1342 | } |
1343 | } |
1344 | |
1345 | if (!hasDefault) { |
1346 | if (is<AST::NativeTypeDeclaration>(*valueType)) { |
1347 | HashSet<int64_t> values; |
1348 | bool zeroValueExists; |
1349 | for (auto& switchCase : switchStatement.switchCases()) { |
1350 | int64_t value; |
1351 | switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) { |
1352 | value = integerLiteral.valueForSelectedType(); |
1353 | }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) { |
1354 | value = unsignedIntegerLiteral.valueForSelectedType(); |
1355 | }, [](auto&) { |
1356 | ASSERT_NOT_REACHED(); |
1357 | })); |
1358 | if (!value) |
1359 | zeroValueExists = true; |
1360 | else |
1361 | values.add(value); |
1362 | } |
1363 | bool success = true; |
1364 | downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool { |
1365 | if (!value) { |
1366 | if (!zeroValueExists) { |
1367 | success = false; |
1368 | return true; |
1369 | } |
1370 | return false; |
1371 | } |
1372 | if (!values.contains(value)) { |
1373 | success = false; |
1374 | return true; |
1375 | } |
1376 | return false; |
1377 | }); |
1378 | if (!success) { |
1379 | setError(); |
1380 | return; |
1381 | } |
1382 | } else { |
1383 | ASSERT(is<AST::EnumerationDefinition>(*valueType)); |
1384 | HashSet<AST::EnumerationMember*> values; |
1385 | for (auto& switchCase : switchStatement.switchCases()) { |
1386 | switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) { |
1387 | ASSERT(enumerationMemberLiteral.enumerationMember()); |
1388 | values.add(enumerationMemberLiteral.enumerationMember()); |
1389 | }, [](auto&) { |
1390 | ASSERT_NOT_REACHED(); |
1391 | })); |
1392 | } |
1393 | for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) { |
1394 | if (!values.contains(&enumerationMember.get())) { |
1395 | setError(); |
1396 | return; |
1397 | } |
1398 | } |
1399 | } |
1400 | } |
1401 | } |
1402 | |
1403 | void Checker::visit(AST::CommaExpression& commaExpression) |
1404 | { |
1405 | ASSERT(commaExpression.list().size() > 0); |
1406 | Visitor::visit(commaExpression); |
1407 | if (error()) |
1408 | return; |
1409 | auto lastInfo = getInfo(commaExpression.list().last()); |
1410 | forwardType(commaExpression, lastInfo->resolvingType); |
1411 | } |
1412 | |
1413 | void Checker::visit(AST::TernaryExpression& ternaryExpression) |
1414 | { |
1415 | auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate()); |
1416 | if (!predicateInfo) |
1417 | return; |
1418 | |
1419 | auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression()); |
1420 | auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression()); |
1421 | |
1422 | auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType); |
1423 | if (!resultType) { |
1424 | setError(); |
1425 | return; |
1426 | } |
1427 | |
1428 | assignType(ternaryExpression, WTFMove(*resultType)); |
1429 | } |
1430 | |
1431 | void Checker::visit(AST::CallExpression& callExpression) |
1432 | { |
1433 | Vector<std::reference_wrapper<ResolvingType>> types; |
1434 | types.reserveInitialCapacity(callExpression.arguments().size()); |
1435 | for (auto& argument : callExpression.arguments()) { |
1436 | auto argumentInfo = recurseAndGetInfo(argument); |
1437 | if (!argumentInfo) |
1438 | return; |
1439 | types.uncheckedAppend(argumentInfo->resolvingType); |
1440 | } |
1441 | // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later. |
1442 | // We don't want to recurse to the same node twice. |
1443 | |
1444 | ASSERT(callExpression.hasOverloads()); |
1445 | auto* function = resolveFunctionOverloadImpl(*callExpression.overloads(), types, callExpression.castReturnType()); |
1446 | if (!function) { |
1447 | if (auto newFunction = resolveByInstantiation(callExpression, types, m_intrinsics)) { |
1448 | m_program.append(WTFMove(*newFunction)); |
1449 | function = &m_program.nativeFunctionDeclarations().last(); |
1450 | } |
1451 | } |
1452 | |
1453 | if (!function) { |
1454 | setError(); |
1455 | return; |
1456 | } |
1457 | |
1458 | for (size_t i = 0; i < function->parameters().size(); ++i) { |
1459 | if (!matchAndCommit(types[i].get(), *function->parameters()[i].type())) { |
1460 | setError(); |
1461 | return; |
1462 | } |
1463 | } |
1464 | |
1465 | callExpression.setFunction(*function); |
1466 | |
1467 | assignType(callExpression, function->type().clone()); |
1468 | } |
1469 | |
1470 | bool check(Program& program) |
1471 | { |
1472 | Checker checker(program.intrinsics(), program); |
1473 | checker.checkErrorAndVisit(program); |
1474 | if (checker.error()) |
1475 | return false; |
1476 | return checker.assignTypes(); |
1477 | } |
1478 | |
1479 | } // namespace WHLSL |
1480 | |
1481 | } // namespace WebCore |
1482 | |
1483 | #endif // ENABLE(WEBGPU) |
1484 | |