1/*
2 * Copyright (C) 2019 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23 * THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "config.h"
27#include "WHLSLSemanticMatcher.h"
28
29#if ENABLE(WEBGPU)
30
31#include "WHLSLBuiltInSemantic.h"
32#include "WHLSLFunctionDefinition.h"
33#include "WHLSLGatherEntryPointItems.h"
34#include "WHLSLInferTypes.h"
35#include "WHLSLPipelineDescriptor.h"
36#include "WHLSLProgram.h"
37#include "WHLSLResourceSemantic.h"
38#include "WHLSLStageInOutSemantic.h"
39#include <wtf/HashMap.h>
40#include <wtf/HashSet.h>
41#include <wtf/Optional.h>
42#include <wtf/text/WTFString.h>
43
44namespace WebCore {
45
46namespace WHLSL {
47
48static AST::FunctionDefinition* findEntryPoint(Vector<UniqueRef<AST::FunctionDefinition>>& functionDefinitions, String& name)
49{
50 auto iterator = std::find_if(functionDefinitions.begin(), functionDefinitions.end(), [&](AST::FunctionDefinition& functionDefinition) {
51 return functionDefinition.entryPointType() && functionDefinition.name() == name;
52 });
53 if (iterator == functionDefinitions.end())
54 return nullptr;
55 return &*iterator;
56};
57
58static bool matchMode(BindingType bindingType, AST::ResourceSemantic::Mode mode)
59{
60 switch (bindingType) {
61 case BindingType::UniformBuffer:
62 return mode == AST::ResourceSemantic::Mode::Buffer;
63 case BindingType::Sampler:
64 return mode == AST::ResourceSemantic::Mode::Sampler;
65 case BindingType::Texture:
66 return mode == AST::ResourceSemantic::Mode::Texture;
67 default:
68 ASSERT(bindingType == BindingType::StorageBuffer);
69 return mode == AST::ResourceSemantic::Mode::UnorderedAccessView;
70 }
71}
72
73static Optional<HashMap<Binding*, size_t>> matchResources(Vector<EntryPointItem>& entryPointItems, Layout& layout, ShaderStage shaderStage)
74{
75 HashMap<Binding*, size_t> result;
76 HashSet<size_t> itemIndices;
77 if (entryPointItems.size() == std::numeric_limits<size_t>::max())
78 return WTF::nullopt; // Work around the fact that HashSet's keys are restricted.
79 for (auto& bindGroup : layout) {
80 auto space = bindGroup.name;
81 for (auto& binding : bindGroup.bindings) {
82 if (!binding.visibility.contains(shaderStage))
83 continue;
84 for (size_t i = 0; i < entryPointItems.size(); ++i) {
85 auto& item = entryPointItems[i];
86 auto& semantic = *item.semantic;
87 if (!WTF::holds_alternative<AST::ResourceSemantic>(semantic))
88 continue;
89 auto& resourceSemantic = WTF::get<AST::ResourceSemantic>(semantic);
90 if (!matchMode(binding.bindingType, resourceSemantic.mode()))
91 continue;
92 if (binding.name != resourceSemantic.index())
93 continue;
94 if (space != resourceSemantic.space())
95 continue;
96 result.add(&binding, i);
97 itemIndices.add(i + 1); // Work around the fact that HashSet's keys are restricted.
98 }
99 }
100 }
101
102 for (size_t i = 0; i < entryPointItems.size(); ++i) {
103 auto& item = entryPointItems[i];
104 auto& semantic = *item.semantic;
105 if (!WTF::holds_alternative<AST::ResourceSemantic>(semantic))
106 continue;
107 if (!itemIndices.contains(i + 1))
108 return WTF::nullopt;
109 }
110
111 return result;
112}
113
114static bool matchInputsOutputs(Vector<EntryPointItem>& vertexOutputs, Vector<EntryPointItem>& fragmentInputs)
115{
116 for (auto& fragmentInput : fragmentInputs) {
117 if (!WTF::holds_alternative<AST::StageInOutSemantic>(*fragmentInput.semantic))
118 continue;
119 auto& fragmentInputStageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*fragmentInput.semantic);
120 bool found = false;
121 for (auto& vertexOutput : vertexOutputs) {
122 if (!WTF::holds_alternative<AST::StageInOutSemantic>(*vertexOutput.semantic))
123 continue;
124 auto& vertexOutputStageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*vertexOutput.semantic);
125 if (fragmentInputStageInOutSemantic.index() == vertexOutputStageInOutSemantic.index()) {
126 if (matches(*fragmentInput.unnamedType, *vertexOutput.unnamedType)) {
127 found = true;
128 break;
129 }
130 return false;
131 }
132 }
133 if (!found)
134 return false;
135 }
136 return true;
137}
138
139static bool isAcceptableFormat(VertexFormat vertexFormat, AST::UnnamedType& unnamedType, Intrinsics& intrinsics)
140{
141 switch (vertexFormat) {
142 case VertexFormat::FloatR32G32B32A32:
143 return matches(unnamedType, intrinsics.float4Type());
144 case VertexFormat::FloatR32G32B32:
145 return matches(unnamedType, intrinsics.float3Type());
146 case VertexFormat::FloatR32G32:
147 return matches(unnamedType, intrinsics.float2Type());
148 default:
149 ASSERT(vertexFormat == VertexFormat::FloatR32);
150 return matches(unnamedType, intrinsics.floatType());
151 }
152}
153
154static Optional<HashMap<VertexAttribute*, size_t>> matchVertexAttributes(Vector<EntryPointItem>& vertexInputs, VertexAttributes& vertexAttributes, Intrinsics& intrinsics)
155{
156 HashMap<VertexAttribute*, size_t> result;
157 HashSet<size_t> itemIndices;
158 if (vertexInputs.size() == std::numeric_limits<size_t>::max())
159 return WTF::nullopt; // Work around the fact that HashSet's keys are restricted.
160 for (auto& vertexAttribute : vertexAttributes) {
161 for (size_t i = 0; i < vertexInputs.size(); ++i) {
162 auto& item = vertexInputs[i];
163 auto& semantic = *item.semantic;
164 if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
165 continue;
166 auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
167 if (stageInOutSemantic.index() != vertexAttribute.name)
168 continue;
169 if (!isAcceptableFormat(vertexAttribute.vertexFormat, *item.unnamedType, intrinsics))
170 return WTF::nullopt;
171 result.add(&vertexAttribute, i);
172 itemIndices.add(i + 1); // Work around the fact that HashSet's keys are restricted.
173 }
174 }
175
176 for (size_t i = 0; i < vertexInputs.size(); ++i) {
177 auto& item = vertexInputs[i];
178 auto& semantic = *item.semantic;
179 if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
180 continue;
181 if (!itemIndices.contains(i + 1))
182 return WTF::nullopt;
183 }
184
185 return result;
186}
187
188static bool isAcceptableFormat(TextureFormat textureFormat, AST::UnnamedType& unnamedType, Intrinsics& intrinsics, bool isColor)
189{
190 if (isColor) {
191 switch (textureFormat) {
192 case TextureFormat::R8Unorm:
193 case TextureFormat::R8UnormSrgb:
194 case TextureFormat::R8Snorm:
195 case TextureFormat::R16Unorm:
196 case TextureFormat::R16Snorm:
197 case TextureFormat::R16Float:
198 case TextureFormat::R32Float:
199 return matches(unnamedType, intrinsics.floatType());
200 case TextureFormat::RG8Unorm:
201 case TextureFormat::RG8UnormSrgb:
202 case TextureFormat::RG8Snorm:
203 case TextureFormat::RG16Unorm:
204 case TextureFormat::RG16Snorm:
205 case TextureFormat::RG16Float:
206 case TextureFormat::RG32Float:
207 return matches(unnamedType, intrinsics.float2Type());
208 case TextureFormat::B5G6R5Unorm:
209 case TextureFormat::RG11B10Float:
210 return matches(unnamedType, intrinsics.float3Type());
211 case TextureFormat::RGBA8Unorm:
212 case TextureFormat::RGBA8UnormSrgb:
213 case TextureFormat::BGRA8Unorm:
214 case TextureFormat::BGRA8UnormSrgb:
215 case TextureFormat::RGBA8Snorm:
216 case TextureFormat::RGB10A2Unorm:
217 case TextureFormat::RGBA16Unorm:
218 case TextureFormat::RGBA16Snorm:
219 case TextureFormat::RGBA16Float:
220 case TextureFormat::RGBA32Float:
221 return matches(unnamedType, intrinsics.float4Type());
222 case TextureFormat::R8Uint:
223 return matches(unnamedType, intrinsics.ucharType());
224 case TextureFormat::R8Sint:
225 return matches(unnamedType, intrinsics.charType());
226 case TextureFormat::R16Uint:
227 return matches(unnamedType, intrinsics.ushortType());
228 case TextureFormat::R16Sint:
229 return matches(unnamedType, intrinsics.shortType());
230 case TextureFormat::R32Uint:
231 return matches(unnamedType, intrinsics.uintType());
232 case TextureFormat::R32Sint:
233 return matches(unnamedType, intrinsics.intType());
234 case TextureFormat::RG8Uint:
235 return matches(unnamedType, intrinsics.uchar2Type());
236 case TextureFormat::RG8Sint:
237 return matches(unnamedType, intrinsics.char2Type());
238 case TextureFormat::RG16Uint:
239 return matches(unnamedType, intrinsics.ushort2Type());
240 case TextureFormat::RG16Sint:
241 return matches(unnamedType, intrinsics.short2Type());
242 case TextureFormat::RG32Uint:
243 return matches(unnamedType, intrinsics.uint2Type());
244 case TextureFormat::RG32Sint:
245 return matches(unnamedType, intrinsics.int2Type());
246 case TextureFormat::RGBA8Uint:
247 return matches(unnamedType, intrinsics.uchar4Type());
248 case TextureFormat::RGBA8Sint:
249 return matches(unnamedType, intrinsics.char4Type());
250 case TextureFormat::RGBA16Uint:
251 return matches(unnamedType, intrinsics.ushort4Type());
252 case TextureFormat::RGBA16Sint:
253 return matches(unnamedType, intrinsics.short4Type());
254 case TextureFormat::RGBA32Uint:
255 return matches(unnamedType, intrinsics.uint4Type());
256 case TextureFormat::RGBA32Sint:
257 return matches(unnamedType, intrinsics.int4Type());
258 default:
259 ASSERT_NOT_REACHED();
260 return false;
261 }
262 }
263 return false;
264}
265
266static Optional<HashMap<AttachmentDescriptor*, size_t>> matchColorAttachments(Vector<EntryPointItem>& fragmentOutputs, Vector<AttachmentDescriptor>& attachmentDescriptors, Intrinsics& intrinsics)
267{
268 HashMap<AttachmentDescriptor*, size_t> result;
269 HashSet<size_t> itemIndices;
270 if (attachmentDescriptors.size() == std::numeric_limits<size_t>::max())
271 return WTF::nullopt; // Work around the fact that HashSet's keys are restricted.
272 for (auto& attachmentDescriptor : attachmentDescriptors) {
273 for (size_t i = 0; i < fragmentOutputs.size(); ++i) {
274 auto& item = fragmentOutputs[i];
275 auto& semantic = *item.semantic;
276 if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
277 continue;
278 auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
279 if (stageInOutSemantic.index() != attachmentDescriptor.name)
280 continue;
281 if (!isAcceptableFormat(attachmentDescriptor.textureFormat, *item.unnamedType, intrinsics, true))
282 return WTF::nullopt;
283 result.add(&attachmentDescriptor, i);
284 itemIndices.add(i + 1); // Work around the fact that HashSet's keys are restricted.
285 }
286 }
287
288 for (size_t i = 0; i < fragmentOutputs.size(); ++i) {
289 auto& item = fragmentOutputs[i];
290 auto& semantic = *item.semantic;
291 if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
292 continue;
293 if (!itemIndices.contains(i + 1))
294 return WTF::nullopt;
295 }
296
297 return result;
298}
299
300static bool matchDepthAttachment(Vector<EntryPointItem>& fragmentOutputs, Optional<AttachmentDescriptor>& depthStencilAttachmentDescriptor, Intrinsics& intrinsics)
301{
302 auto iterator = std::find_if(fragmentOutputs.begin(), fragmentOutputs.end(), [&](EntryPointItem& item) {
303 auto& semantic = *item.semantic;
304 if (!WTF::holds_alternative<AST::BuiltInSemantic>(semantic))
305 return false;
306 auto& builtInSemantic = WTF::get<AST::BuiltInSemantic>(semantic);
307 return builtInSemantic.variable() == AST::BuiltInSemantic::Variable::SVDepth;
308 });
309 if (iterator == fragmentOutputs.end())
310 return true;
311
312 if (depthStencilAttachmentDescriptor) {
313 ASSERT(!depthStencilAttachmentDescriptor->name);
314 return isAcceptableFormat(depthStencilAttachmentDescriptor->textureFormat, *iterator->unnamedType, intrinsics, false);
315 }
316 return false;
317}
318
319Optional<MatchedRenderSemantics> matchSemantics(Program& program, RenderPipelineDescriptor& renderPipelineDescriptor)
320{
321 auto vertexShaderEntryPoint = findEntryPoint(program.functionDefinitions(), renderPipelineDescriptor.vertexEntryPointName);
322 auto fragmentShaderEntryPoint = findEntryPoint(program.functionDefinitions(), renderPipelineDescriptor.fragmentEntryPointName);
323 if (!vertexShaderEntryPoint || !fragmentShaderEntryPoint)
324 return WTF::nullopt;
325 auto vertexShaderEntryPointItems = gatherEntryPointItems(program.intrinsics(), *vertexShaderEntryPoint);
326 auto fragmentShaderEntryPointItems = gatherEntryPointItems(program.intrinsics(), *fragmentShaderEntryPoint);
327 if (!vertexShaderEntryPointItems || !fragmentShaderEntryPointItems)
328 return WTF::nullopt;
329 auto vertexShaderResourceMap = matchResources(vertexShaderEntryPointItems->inputs, renderPipelineDescriptor.layout, ShaderStage::Vertex);
330 auto fragmentShaderResourceMap = matchResources(fragmentShaderEntryPointItems->inputs, renderPipelineDescriptor.layout, ShaderStage::Fragment);
331 if (!vertexShaderResourceMap || !fragmentShaderResourceMap)
332 return WTF::nullopt;
333 if (!matchInputsOutputs(vertexShaderEntryPointItems->outputs, fragmentShaderEntryPointItems->inputs))
334 return WTF::nullopt;
335 auto matchedVertexAttributes = matchVertexAttributes(vertexShaderEntryPointItems->inputs, renderPipelineDescriptor.vertexAttributes, program.intrinsics());
336 if (!matchedVertexAttributes)
337 return WTF::nullopt;
338 auto matchedColorAttachments = matchColorAttachments(fragmentShaderEntryPointItems->outputs, renderPipelineDescriptor.attachmentsStateDescriptor.attachmentDescriptors, program.intrinsics());
339 if (!matchedColorAttachments)
340 return WTF::nullopt;
341 if (!matchDepthAttachment(fragmentShaderEntryPointItems->outputs, renderPipelineDescriptor.attachmentsStateDescriptor.depthStencilAttachmentDescriptor, program.intrinsics()))
342 return WTF::nullopt;
343 return {{ vertexShaderEntryPoint, fragmentShaderEntryPoint, *vertexShaderEntryPointItems, *fragmentShaderEntryPointItems, *vertexShaderResourceMap, *fragmentShaderResourceMap, *matchedVertexAttributes, *matchedColorAttachments }};
344}
345
346Optional<MatchedComputeSemantics> matchSemantics(Program& program, ComputePipelineDescriptor& computePipelineDescriptor)
347{
348 auto entryPoint = findEntryPoint(program.functionDefinitions(), computePipelineDescriptor.entryPointName);
349 if (!entryPoint)
350 return WTF::nullopt;
351 auto entryPointItems = gatherEntryPointItems(program.intrinsics(), *entryPoint);
352 if (!entryPointItems)
353 return WTF::nullopt;
354 ASSERT(entryPointItems->outputs.isEmpty());
355 auto resourceMap = matchResources(entryPointItems->inputs, computePipelineDescriptor.layout, ShaderStage::Compute);
356 if (!resourceMap)
357 return WTF::nullopt;
358 return {{ entryPoint, *entryPointItems, *resourceMap }};
359}
360
361} // namespace WHLSL
362
363} // namespace WebCore
364
365#endif // ENABLE(WEBGPU)
366