<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<title>Don't Respond</title>
<link href="https://hjchen2.github.io/atom.xml" rel="self"/>
<link href="https://hjchen2.github.io/"/>
<updated>2023-03-07T06:09:20.211Z</updated>
<id>https://hjchen2.github.io/</id>
<author>
<name>Dou Jiang</name>
</author>
<generator uri="https://hexo.io/">Hexo</generator>
<entry>
<title>IREE Compilation Pipeline Walkthrough (Part 6)</title>
<link href="https://hjchen2.github.io/2023/02/24/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B6/"/>
<id>https://hjchen2.github.io/2023/02/24/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B6/</id>
<published>2023-02-24T10:47:11.000Z</published>
<updated>2023-03-07T06:09:20.211Z</updated>
<content type="html"><![CDATA[<p>The main job of HAL::HALTransformPassPipeline is to perform tiling, vectorization, bufferization and similar transformations, distribute the computational workload, and finally generate code for the target device. For the cuda target, for example, the dispatch source code is lowered to NVVM IR.</p><span id="more"></span><ul><li><p>buildHALConfigurationPassPipeline</p><ul><li><p>addCleanupPatterns</p></li><li><p>createAssignTargetDevicesPass</p><p>Adds the device-targets attribute to the outermost module; multiple target devices can be specified.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">module</span> attributes {hal.device.targets = [<span class="meta">#hal.device.target<span class="string"><"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}></span>], legacy_sync}>]} {</span></span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createVerifyTargetEnvironmentPass</p><p>Verifies that the device targets are set correctly and that the corresponding compiler backends have been registered.</p></li><li><p>createMaterializeInterfacesPass</p><p>Creates a device-target-specific variant for each executable, one executable variant per device target. Both the executable's export and its source func are converted into parameterless funcs, unifying the calling interface across dispatch, export, and source func: the dispatch specifies how inputs map to bindings, and the source func fetches its input arguments by binding id.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span 
class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line">stream.executable <span class="keyword">private</span> @test_dispatch_0 {</span><br><span class="line"> stream.executable.<span class="keyword">export</span> <span class="keyword">public</span> @<span class="function">test_dispatch_0_generic_100000x100 <span class="title">workgroups</span><span class="params">(%arg0: index, %arg1: index)</span> -> <span class="params">(index, index, index)</span> </span>{</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1</span><br><span class="line"> stream.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>(%arg0: !stream.binding {stream.alignment = <span class="number">64</span> : index}, %arg1: !stream.binding {stream.alignment = <span class="number">64</span> : index}, %arg2: !stream.binding {stream.alignment = <span class="number">64</span> : index}) {</span><br><span class="line"> ...</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">3</span> = stream.cmd.execute <span class="built_in">with</span>(%<span class="number">0</span> as %arg2: !stream.resource<external>{%c40000000}, %<span 
class="number">1</span> as %arg3: !stream.resource<external>{%c40000000}, %<span class="number">2</span> as %arg4: !stream.resource<external>{%c400000}) {</span><br><span class="line"> stream.cmd.fill %c0_i8, %arg4[%c0 <span class="keyword">for</span> %c400000] : i8 -> !stream.resource<external>{%c400000}</span><br><span class="line"> stream.cmd.dispatch @test_dispatch_0::@test_dispatch_0_generic_100000x100[%c100000, %c1] {</span><br><span class="line"> ro %arg2[%c0 <span class="keyword">for</span> %c40000000] : !stream.resource<external>{%c40000000},</span><br><span class="line"> ro %arg3[%c0 <span class="keyword">for</span> %c40000000] : !stream.resource<external>{%c40000000},</span><br><span class="line"> rw %arg4[%c0 <span class="keyword">for</span> %c400000] : !stream.resource<external>{%c400000}</span><br><span class="line"> }</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted to</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span 
class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line">hal.executable <span class="keyword">private</span> @test_dispatch_0 {</span><br><span class="line"> hal.executable.variant <span class="keyword">public</span> @cuda_nvptx_fb, target = <<span class="string">"cuda"</span>, <span class="string">"cuda-nvptx-fb"</span>, {target_arch = <span class="string">"sm_35"</span>}> {</span><br><span class="line"> hal.executable.<span class="keyword">export</span> <span class="keyword">public</span> @test_dispatch_0_generic_100000x100 <span class="built_in">ordinal</span>(<span class="number">0</span>) <span class="built_in">layout</span>(<span class="meta">#hal.pipeline.layout<span class="string"><push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly></span>, <span class="string"><1, storage_buffer, ReadOnly></span>, <span class="string"><2, storage_buffer></span>]>]>) {</span></span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg0: !hal.device, %arg1: index, %arg2: index):</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2</span><br><span class="line"> hal.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span 
class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>></span><br><span class="line"> %<span class="number">1</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>></span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> ...</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">3</span> = stream.cmd.execute <span class="built_in">with</span>(%<span class="number">0</span> as %arg2: !stream.resource<external>{%c40000000}, %<span class="number">1</span> as %arg3: !stream.resource<external>{%c40000000}, %<span class="number">2</span> as %arg4: !stream.resource<external>{%c400000}) {</span><br><span class="line"> stream.cmd.fill %c0_i8, %arg4[%c0 <span 
class="keyword">for</span> %c400000] : i8 -> !stream.resource<external>{%c400000}</span><br><span class="line"> stream.cmd.dispatch @test_dispatch_0::@test_dispatch_0_generic_100000x100[%c100000, %c1] {</span><br><span class="line"> ro %arg2[%c0 <span class="keyword">for</span> %c40000000] : !stream.resource<external>{%c40000000},</span><br><span class="line"> ro %arg3[%c0 <span class="keyword">for</span> %c40000000] : !stream.resource<external>{%c40000000},</span><br><span class="line"> rw %arg4[%c0 <span class="keyword">for</span> %c400000] : !stream.resource<external>{%c400000}</span><br><span class="line"> } attributes {hal.interface.bindings = [<span class="meta">#hal.interface.binding<span class="string"><0, 0></span>, #hal.interface.binding<span class="string"><0, 1></span>, #hal.interface.binding<span class="string"><0, 2></span>]}</span></span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul></li><li><p>createTranslateExecutablesPass</p><p>Compiles each <code>hal.executable.variant</code> with the backend matching its target device. For cuda, for example, CUDATargetBackend is invoked, which in effect runs the sequence of passes below.</p><ul><li><p>buildLLVMGPUTransformPassPipeline</p><ul><li><p>createTypePropagationPass</p><p>Normalizes integer element types and propagates the modified types.</p></li><li><p>createBufferizeCopyOnlyDispatchesPass</p><p>Converts pure data-copy dispatches (containing only tensor loads and stores) into linalg generic ops and bufferizes them.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0</span>() {</span><br><span class="line"> %c0 = 
arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.constant.load[<span class="number">0</span>] : i32</span><br><span class="line"> %<span class="number">1</span> = arith.index_castui %<span class="number">0</span> : i32 to index</span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%<span class="number">1</span>}</span><br><span class="line"> %<span class="number">3</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%<span class="number">1</span>}</span><br><span class="line"> %<span class="number">4</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [%<span class="number">1</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%<span class="number">1</span>} -> tensor<?xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">4</span>, %<span class="number">3</span>, offsets = [<span class="number">0</span>], sizes = [%<span class="number">1</span>], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%<span class="number">1</span>}</span><br><span class="line"> <span 
class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted to</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0</span>() {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.constant.load[<span class="number">0</span>] : i32</span><br><span class="line"> %<span class="number">1</span> = arith.index_castui %<span class="number">0</span> : i32 to index</span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<?xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>{%1}</span></span><br><span class="line"> memref.assume_alignment %<span class="number">2</span>, <span class="number">64</span> : memref<?xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %<span class="number">3</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span 
class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<?xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>{%1}</span></span><br><span class="line"> memref.assume_alignment %<span class="number">3</span>, <span class="number">64</span> : memref<?xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">2</span> : memref<?xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>) outs(%3 : memref<span class="string"><?xf32, #hal.descriptor_type<storage_buffer></span>>) {</span></span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %out: f32):</span><br><span class="line"> linalg.yield %in : f32</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createEraseHALDescriptorTypeFromMemRefPass</p><p>Converts values whose memory space is the hal descriptor type into plain memrefs.</p></li><li><p>createLLVMGPULowerExecutableTargetPass</p><ul><li><p>initGPULaunchConfig</p><p>Computes the gpu launch configuration from the concrete workload and its type, including the tiling strategy, group count, thread num, and which lowering pipeline subsequent passes are dispatched to.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span 
class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">hal.executable.variant <span class="keyword">public</span> @cuda_nvptx_fb, target = <<span class="string">"cuda"</span>, <span class="string">"cuda-nvptx-fb"</span>, {target_arch = <span class="string">"sm_35"</span>}> {</span><br><span class="line"> hal.executable.<span class="keyword">export</span> <span class="keyword">public</span> @test_dispatch_0_generic_100000x100 <span class="built_in">ordinal</span>(<span class="number">0</span>) <span class="built_in">layout</span>(<span class="meta">#hal.pipeline.layout<span class="string"><push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly></span>, <span class="string"><1, storage_buffer, ReadOnly></span>, <span class="string"><2, storage_buffer></span>]>]>) {</span></span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg0: !hal.device, %arg1: index, %arg2: index):</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2</span><br><span class="line"> hal.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], 
iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span>, %<span class="number">4</span> : tensor<<span class="number">100000</span>x100xf32>, tensor<<span class="number">100000</span>x100xf32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">100000</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">7</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">8</span> = arith.addf %<span class="number">7</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">100000</span>xf32></span><br><span class="line"> ...</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted to</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">hal.executable.variant <span class="keyword">public</span> @cuda_nvptx_fb, target = <<span class="string">"cuda"</span>, <span class="string">"cuda-nvptx-fb"</span>, {target_arch = <span 
class="string">"sm_35"</span>}> {</span><br><span class="line"> hal.executable.<span class="keyword">export</span> <span class="keyword">public</span> @test_dispatch_0_generic_100000x100 <span class="built_in">ordinal</span>(<span class="number">0</span>) <span class="built_in">layout</span>(<span class="meta">#hal.pipeline.layout<span class="string"><push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly></span>, <span class="string"><1, storage_buffer, ReadOnly></span>, <span class="string"><2, storage_buffer></span>]>]>) attributes {translation_info = #iree_codegen.translation_info<span class="string"><LLVMGPUVectorize></span>, workgroup_size = [64 : index, 1 : index, 1 : index]} {</span></span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg0: !hal.device, %arg1: index, %arg2: index):</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2</span><br><span class="line"> hal.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span>, %<span class="number">4</span> : tensor<<span class="number">100000</span>x100xf32>, tensor<<span class="number">100000</span>x100xf32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">100000</span>xf32>) attrs = 
{lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">7</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">8</span> = arith.addf %<span class="number">7</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">100000</span>xf32></span><br><span class="line"> ...</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>Notice that the export func has gained two attributes, translation_info and workgroup_size, and the source func has gained a lowering_config attribute. translation_info says that subsequent lowering is dispatched to the LLVMGPUVectorize pipeline. workgroup_size can be thought of as the 3-dimensional gpu blockDim; here it means each thread block has 64 threads. lowering_config specifies the tiling strategy for each loop level; here it means one thread block computes 256 rows of 100xf32 data, and each thread computes one 4xf32 vector at a time.</p></li><li><p>DispatchLoweringPassPipeline</p><p>Dispatches, according to translation_info, to one of the pipelines below to continue lowering.</p><ul><li><p>GPUSimpleDistributePassPipeline</p></li><li><p>GPUVectorizationPassPipeline</p><ul><li><p>getTileAndDistributeConfig</p><p>Locates the root op of the dispatch (usually the last linalg reduction op; if there is no reduction op, the last linalg generic op is chosen), reads the lowering_config (tile sizes) from the op's attributes, sets the tile sizes of the non-parallel loops to 0 (meaning only the parallel loops will be vectorized next), and computes the loop ranges of the parallel loops.</p></li><li><p>LowerDispatchWorkgroupCountForDagRootOp</p><p>Computes the workgroup count from the loop ranges and tile sizes.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">hal.executable.<span class="keyword">export</span> <span class="keyword">public</span> @<span 
class="function">test_dispatch_0_generic_100000x100 <span class="title">ordinal</span><span class="params">(<span class="number">0</span>)</span> <span class="title">layout</span><span class="params">(#hal.pipeline.layout<push_constants = <span class="number">0</span>, sets = [<<span class="number">0</span>, bindings = [<<span class="number">0</span>, storage_buffer, ReadOnly>, <<span class="number">1</span>, storage_buffer, ReadOnly>, <<span class="number">2</span>, storage_buffer>]>]>)</span> attributes </span>{translation_info = #iree_codegen.translation_info<LLVMGPUVectorize>, workgroup_size = [<span class="number">64</span> : index, <span class="number">1</span> : index, <span class="number">1</span> : index]} {</span><br><span class="line">^<span class="built_in">bb0</span>(%arg0: !hal.device, %arg1: index, %arg2: index):</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2</span><br><span class="line"> hal.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted to</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">hal.executable.<span class="keyword">export</span> <span class="keyword">public</span> @<span class="function">test_dispatch_0_generic_100000x100 <span class="title">ordinal</span><span class="params">(<span class="number">0</span>)</span> <span class="title">layout</span><span class="params">(#hal.pipeline.layout<push_constants = <span class="number">0</span>, sets = [<<span class="number">0</span>, bindings = [<<span class="number">0</span>, storage_buffer, ReadOnly>, <<span class="number">1</span>, storage_buffer, ReadOnly>, <<span class="number">2</span>, 
storage_buffer>]>]>)</span> attributes </span>{translation_info = #iree_codegen.translation_info<LLVMGPUVectorize>, workgroup_size = [<span class="number">64</span> : index, <span class="number">1</span> : index, <span class="number">1</span> : index]} {</span><br><span class="line">^<span class="built_in">bb0</span>(%arg0: !hal.device, %arg1: index, %arg2: index):</span><br><span class="line"> %c391 = arith.constant <span class="number">391</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> hal.<span class="keyword">return</span> %c391, %c1, %c1 : index, index, index</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>The computed group count is (391, 1, 1), where 391 = UDIV(100000, 256), a ceiling division of the loop range by the tile size.</p></li><li><p>populateTileAndDistributeToWorkgroupsPatterns</p><p>Tiles the parallel loops and rewrites the source func into the computation performed by a single workgroup.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span>, %<span class="number">4</span> : tensor<<span class="number">100000</span>x100xf32>, 
tensor<<span class="number">100000</span>x100xf32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">100000</span>xf32>) attrs = {__internal_linalg_transform__ = <span class="string">"__workgroup_tiling__"</span>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">7</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">8</span> = arith.addf %<span class="number">7</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">100000</span>xf32></span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %workgroup_count_x = 
hal.interface.workgroup.count[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">4</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_count_x]</span><br><span class="line"> scf.<span class="keyword">for</span> %arg0 = %<span class="number">3</span> to %c100000 step %<span class="number">4</span> {</span><br><span class="line"> %<span class="number">5</span> = affine.min <span class="built_in">affine_map</span><(d0) -> (<span class="number">256</span>, -d0 + <span class="number">100000</span>)>(%arg0)</span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%<span class="number">5</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%<span class="number">5</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%arg0], sizes = [%<span class="number">5</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<?xf32></span><br><span 
class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">6</span>, %<span class="number">7</span> : tensor<?x100xf32>, tensor<?x100xf32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">10</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">11</span> = arith.addf %<span class="number">10</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createWorkgroupSpecializationPass</p><p>Splits the tiled computation into two variants: one for the full static-shape tile, and one for the remaining dynamic-shape tile.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span 
class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %workgroup_count_x = hal.interface.workgroup.count[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">4</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_count_x]</span><br><span class="line"> scf.<span class="keyword">for</span> %arg0 = %<span class="number">3</span> to %c100000 step %<span class="number">4</span> {</span><br><span class="line"> %<span class="number">5</span> = affine.min <span class="built_in">affine_map</span><(d0) -> (<span class="number">256</span>, -d0 + <span class="number">100000</span>)>(%arg0)</span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%<span class="number">5</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%<span class="number">5</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span 
class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%arg0], sizes = [%<span class="number">5</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<?xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">6</span>, %<span class="number">7</span> : tensor<?x100xf32>, tensor<?x100xf32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">10</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">11</span> = arith.addf %<span class="number">10</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %workgroup_count_x = hal.interface.workgroup.count[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">4</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_count_x]</span><br><span class="line"> scf.<span class="keyword">for</span> %arg0 = %<span class="number">3</span> to %c100000 step %<span class="number">4</span> {</span><br><span class="line"> %<span class="number">5</span> = affine.min <span 
class="built_in">affine_map</span><(d0) -> (-d0 + <span class="number">100000</span>, <span class="number">256</span>)>(%arg0)</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %<span class="number">6</span> = arith.cmpi eq, %<span class="number">5</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">6</span> {</span><br><span class="line"> <span class="comment">// Compute the static-shape [256, 100] tile</span></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%c256, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%c256, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">9</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%arg0], sizes = [%c256], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<?xf32></span><br><span class="line"> %<span class="number">10</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span 
class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">7</span>, %<span class="number">8</span> : tensor<?x100xf32>, tensor<?x100xf32>) <span class="built_in">outs</span>(%<span class="number">9</span> : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">12</span> = arith.addf %<span class="number">11</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">12</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">10</span>, %<span class="number">2</span>, offsets = [%arg0], sizes = [%c256], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> <span class="comment">// Compute the remaining dynamic-shape [%5, 100] tile</span></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%<span class="number">5</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%arg0, <span class="number">0</span>], sizes = [%<span class="number">5</span>, <span class="number">100</span>], strides = [<span 
class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">9</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%arg0], sizes = [%<span class="number">5</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<?xf32></span><br><span class="line"> %<span class="number">10</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">7</span>, %<span class="number">8</span> : tensor<?x100xf32>, tensor<?x100xf32>) <span class="built_in">outs</span>(%<span class="number">9</span> : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">12</span> = arith.addf %<span class="number">11</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">12</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">10</span>, %<span class="number">2</span>, offsets = [%arg0], sizes = [%<span class="number">5</span>], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span 
class="number">100000</span>xf32>></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createRemoveSingleIterationLoopPass</p><p>Removes loops that are guaranteed to execute exactly once. For example, the <code>scf.for %arg0 = %3 to %c100000 step %4</code> loop above iterates only once, because step = 256 * 391 = 100096 > 100000, so the loop is eliminated, producing the following code.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> ...</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = affine.apply <span 
class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">4</span> = affine.min <span class="built_in">affine_map</span><(d0) -> (-d0 + <span class="number">100000</span>, <span class="number">256</span>)>(%<span class="number">3</span>)</span><br><span class="line"> %<span class="number">5</span> = arith.cmpi eq, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">5</span> {</span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [<span class="number">256</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<<span class="number">256</span>x100xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [<span class="number">256</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<<span class="number">256</span>x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [<span class="number">256</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<<span class="number">256</span>xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = 
[<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">6</span>, %<span class="number">7</span> : tensor<<span class="number">256</span>x100xf32>, tensor<<span class="number">256</span>x100xf32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<<span class="number">256</span>xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">10</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">11</span> = arith.addf %<span class="number">10</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">256</span>xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">9</span>, %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [<span class="number">256</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">256</span>xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [%<span class="number">4</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span 
class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [%<span class="number">4</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [%<span class="number">4</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<?xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">6</span>, %<span class="number">7</span> : tensor<?x100xf32>, tensor<?x100xf32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">10</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">11</span> = arith.addf %<span class="number">10</span>, %out : 
f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">9</span>, %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [%<span class="number">4</span>], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createLLVMGPUTileTensor</p><p>The preceding passes mainly tiled and distributed the outer parallel loops, producing the computation for a single thread block. This pass goes on to distribute that work across the individual threads, and also tiles the inner reduction loop in preparation for vectorization. The code above is further transformed into the following:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span 
class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %c100 = arith.constant <span class="number">100</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c64 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : 
index</span><br><span class="line"> ...</span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">4</span> = affine.min <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">-256</span> + <span class="number">100000</span>, <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">5</span> = arith.cmpi eq, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">5</span> {</span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [<span class="number">256</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<<span class="number">256</span>x100xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [<span class="number">256</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<<span class="number">256</span>x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [<span class="number">256</span>], strides = 
[<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<<span class="number">256</span>xf32></span><br><span class="line"> <span class="comment">// 64 threads compute in parallel; each thread handles a [4, 100] tile</span></span><br><span class="line"> %<span class="number">9</span> = scf.foreach_thread (%arg0) <span class="built_in">in</span> (%c64) <span class="built_in">shared_outs</span>(%arg1 = %<span class="number">8</span>) -> (tensor<<span class="number">256</span>xf32>) {</span><br><span class="line"> %<span class="number">10</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 * <span class="number">4</span>)>(%arg0)</span><br><span class="line"> %<span class="number">11</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 * <span class="number">4</span>)>(%arg0)</span><br><span class="line"> %<span class="number">12</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 * <span class="number">4</span>)>(%arg0)</span><br><span class="line"> %extracted_slice = tensor.extract_slice %<span class="number">6</span>[%<span class="number">10</span>, <span class="number">0</span>] [<span class="number">4</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">256</span>x100xf32> to tensor<<span class="number">4</span>x100xf32></span><br><span class="line"> %extracted_slice_0 = tensor.extract_slice %<span class="number">7</span>[%<span class="number">11</span>, <span class="number">0</span>] [<span class="number">4</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">256</span>x100xf32> to tensor<<span class="number">4</span>x100xf32></span><br><span class="line"> %extracted_slice_1 = tensor.extract_slice %arg1[%<span class="number">12</span>] [<span class="number">4</span>] [<span class="number">1</span>] : 
tensor<<span class="number">256</span>xf32> to tensor<<span class="number">4</span>xf32></span><br><span class="line"> <span class="comment">// vectorization of the inner reduction loop</span></span><br><span class="line"> %<span class="number">13</span> = scf.<span class="keyword">for</span> %arg2 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg3 = %extracted_slice_1) -> (tensor<<span class="number">4</span>xf32>) {</span><br><span class="line"> %extracted_slice_2 = tensor.extract_slice %extracted_slice[<span class="number">0</span>, %arg2] [<span class="number">4</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">4</span>x100xf32> to tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[<span class="number">0</span>, %arg2] [<span class="number">4</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">4</span>x100xf32> to tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">15</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%extracted_slice_2, %extracted_slice_3 : tensor<<span class="number">4</span>x4xf32>, tensor<<span class="number">4</span>x4xf32>) <span class="built_in">outs</span>(%arg3 : tensor<<span class="number">4</span>xf32>) attrs = {__internal_linalg_transform__ = <span class="string">"workgroup_k_tiled"</span>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span 
class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_4: f32, %out: f32):</span><br><span class="line"> %<span class="number">16</span> = arith.addf %in, %in_4 : f32</span><br><span class="line"> %<span class="number">17</span> = arith.addf %<span class="number">16</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">17</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">15</span> : tensor<<span class="number">4</span>xf32></span><br><span class="line"> }</span><br><span class="line"> %<span class="number">14</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 * <span class="number">4</span>)>(%arg0)</span><br><span class="line"> scf.foreach_thread.perform_concurrently {</span><br><span class="line"> tensor.parallel_insert_slice %<span class="number">13</span> into %arg1[%<span class="number">14</span>] [<span class="number">4</span>] [<span class="number">1</span>] : tensor<<span class="number">4</span>xf32> into tensor<<span class="number">256</span>xf32></span><br><span class="line"> }</span><br><span class="line"> } {mapping = [<span class="meta">#gpu.thread<span class="string"><x></span>]}</span></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">9</span>, %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [<span class="number">256</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">256</span>xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [%<span class="number">4</span>, <span class="number">100</span>], 
strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [%<span class="number">3</span>, <span class="number">0</span>], sizes = [%<span class="number">4</span>, <span class="number">100</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>> -> tensor<?x100xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [%<span class="number">4</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>> -> tensor<?xf32></span><br><span class="line"> %dim = tensor.dim %<span class="number">6</span>, %c0 : tensor<?x100xf32></span><br><span class="line"> <span class="comment">// 64 threads compute concurrently, each thread computing a [%11, 100] tile</span></span><br><span class="line"> %<span class="number">9</span> = scf.foreach_thread (%arg0) <span class="built_in">in</span> (%c64) <span class="built_in">shared_outs</span>(%arg1 = %<span class="number">8</span>) -> (tensor<?xf32>) {</span><br><span class="line"> %<span class="number">10</span> = affine.min <span class="built_in">affine_map</span><(d0)[s0] -> (-(d0 * (s0 ceildiv <span class="number">64</span>)) + s0, s0 ceildiv <span class="number">64</span>)>(%arg0)[%dim]</span><br><span class="line"> %<span class="number">11</span> = affine.max <span class="built_in">affine_map</span><(d0) -> (<span class="number">0</span>, d0)>(%<span class="number">10</span>)</span><br><span class="line"> %<span class="number">12</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 * (s0 ceildiv <span 
class="number">64</span>))>(%arg0)[%dim]</span><br><span class="line"> %<span class="number">13</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 * (s0 ceildiv <span class="number">64</span>))>(%arg0)[%dim]</span><br><span class="line"> %<span class="number">14</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 * (s0 ceildiv <span class="number">64</span>))>(%arg0)[%dim]</span><br><span class="line"> %extracted_slice = tensor.extract_slice %<span class="number">6</span>[%<span class="number">12</span>, <span class="number">0</span>] [%<span class="number">11</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<?x100xf32> to tensor<?x100xf32></span><br><span class="line"> %extracted_slice_0 = tensor.extract_slice %<span class="number">7</span>[%<span class="number">13</span>, <span class="number">0</span>] [%<span class="number">11</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<?x100xf32> to tensor<?x100xf32></span><br><span class="line"> %extracted_slice_1 = tensor.extract_slice %arg1[%<span class="number">14</span>] [%<span class="number">11</span>] [<span class="number">1</span>] : tensor<?xf32> to tensor<?xf32></span><br><span class="line"> <span class="comment">// vectorization of the inner reduction loop</span></span><br><span class="line"> %<span class="number">15</span> = scf.<span class="keyword">for</span> %arg2 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg3 = %extracted_slice_1) -> (tensor<?xf32>) {</span><br><span class="line"> %extracted_slice_2 = tensor.extract_slice %extracted_slice[<span class="number">0</span>, %arg2] [%<span class="number">11</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<?x100xf32> to tensor<?x4xf32></span><br><span class="line"> %extracted_slice_3 = tensor.extract_slice 
%extracted_slice_0[<span class="number">0</span>, %arg2] [%<span class="number">11</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<?x100xf32> to tensor<?x4xf32></span><br><span class="line"> %extracted_slice_4 = tensor.extract_slice %arg3[<span class="number">0</span>] [%<span class="number">11</span>] [<span class="number">1</span>] : tensor<?xf32> to tensor<?xf32></span><br><span class="line"> %<span class="number">17</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%extracted_slice_2, %extracted_slice_3 : tensor<?x4xf32>, tensor<?x4xf32>) <span class="built_in">outs</span>(%extracted_slice_4 : tensor<?xf32>) attrs = {__internal_linalg_transform__ = <span class="string">"workgroup_k_tiled"</span>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_5: f32, %out: f32):</span><br><span class="line"> %<span class="number">18</span> = arith.addf %in, %in_5 : f32</span><br><span class="line"> %<span class="number">19</span> = arith.addf %<span class="number">18</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">19</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> %inserted_slice = tensor.insert_slice %<span class="number">17</span> into %arg3[<span class="number">0</span>] [%<span class="number">11</span>] [<span class="number">1</span>] : tensor<?xf32> into tensor<?xf32></span><br><span class="line"> scf.yield %inserted_slice : tensor<?xf32></span><br><span 
class="line"> }</span><br><span class="line"> %<span class="number">16</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 * (s0 ceildiv <span class="number">64</span>))>(%arg0)[%dim]</span><br><span class="line"> scf.foreach_thread.perform_concurrently {</span><br><span class="line"> tensor.parallel_insert_slice %<span class="number">15</span> into %arg1[%<span class="number">16</span>] [%<span class="number">11</span>] [<span class="number">1</span>] : tensor<?xf32> into tensor<?xf32></span><br><span class="line"> }</span><br><span class="line"> } {mapping = [<span class="meta">#gpu.thread<span class="string"><x></span>]}</span></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">9</span>, %<span class="number">2</span>, offsets = [%<span class="number">3</span>], sizes = [%<span class="number">4</span>], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createRemoveSingleIterationLoopPass</p></li><li><p>createGPUVectorizationPass</p><p>Converts the vectorizable inner linalg ops into vector ops.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">11</span> = scf.<span class="keyword">for</span> %arg2 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg3 = %extracted_slice_1) -> (tensor<<span class="number">4</span>xf32>) 
{</span><br><span class="line"> %extracted_slice_2 = tensor.extract_slice %extracted_slice[<span class="number">0</span>, %arg2] [<span class="number">4</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">4</span>x100xf32> to tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[<span class="number">0</span>, %arg2] [<span class="number">4</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">4</span>x100xf32> to tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">12</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%extracted_slice_2, %extracted_slice_3 : tensor<<span class="number">4</span>x4xf32>, tensor<<span class="number">4</span>x4xf32>) <span class="built_in">outs</span>(%arg3 : tensor<<span class="number">4</span>xf32>) attrs = {__internal_linalg_transform__ = <span class="string">"workgroup_k_tiled"</span>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_4: f32, %out: f32):</span><br><span class="line"> %<span class="number">13</span> = arith.addf %in, %in_4 : f32</span><br><span class="line"> %<span class="number">14</span> = arith.addf %<span class="number">13</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">14</span> : f32</span><br><span class="line"> } -> tensor<<span 
class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">12</span> : tensor<<span class="number">4</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">11</span> = vector.transfer_read %extracted_slice_1[%c0], %cst {in_bounds = [<span class="literal">true</span>]} : tensor<<span class="number">4</span>xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line">%<span class="number">12</span> = scf.<span class="keyword">for</span> %arg2 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg3 = %<span class="number">11</span>) -> (vector<<span class="number">4</span>xf32>) {</span><br><span class="line"> %extracted_slice_2 = tensor.extract_slice %extracted_slice[<span class="number">0</span>, %arg2] [<span class="number">4</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">4</span>x100xf32> to tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[<span class="number">0</span>, %arg2] [<span class="number">4</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">4</span>x100xf32> to tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">14</span> = vector.transfer_read %extracted_slice_2[%c0, %c0], %cst {in_bounds 
= [<span class="literal">true</span>, <span class="literal">true</span>]} : tensor<<span class="number">4</span>x4xf32>, vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">15</span> = vector.transfer_read %extracted_slice_3[%c0, %c0], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : tensor<<span class="number">4</span>x4xf32>, vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">16</span> = arith.addf %<span class="number">14</span>, %<span class="number">15</span> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">17</span> = vector.multi_reduction <add>, %<span class="number">16</span>, %arg3 [<span class="number">1</span>] : vector<<span class="number">4</span>x4xf32> to vector<<span class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">17</span> : vector<<span class="number">4</span>xf32></span><br><span class="line">}</span><br><span class="line">%<span class="number">13</span> = vector.transfer_write %<span class="number">12</span>, %extracted_slice_1[%c0] {in_bounds = [<span class="literal">true</span>]} : vector<<span class="number">4</span>xf32>, tensor<<span class="number">4</span>xf32></span><br></pre></td></tr></table></figure></li><li><p>addBufferizePasses</p><p>Converts tensor semantics into memref semantics. The complete source func above is converted into the following code:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span 
class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %c100 = arith.constant <span class="number">100</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c64 = arith.constant 
<span class="number">64</span> : index</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> memref.assume_alignment %<span class="number">0</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %<span class="number">1</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>></span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span 
class="string"><storage_buffer></span>></span></span><br><span class="line"> memref.assume_alignment %<span class="number">2</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %<span class="number">3</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : !flow.dispatch.tensor<readonly:tensor<<span class="number">100000</span>x100xf32>></span><br><span class="line"> %<span class="number">4</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> memref.assume_alignment %<span class="number">4</span>, <span class="number">64</span> : memref<<span class="number">100000</span>xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %<span class="number">5</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : 
!flow.dispatch.tensor<readwrite:tensor<<span class="number">100000</span>xf32>></span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">6</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">7</span> = affine.min <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">-256</span> + <span class="number">100000</span>, <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">8</span> = arith.cmpi eq, %<span class="number">7</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">8</span> {</span><br><span class="line"> %subview = memref.subview %<span class="number">0</span>[%<span class="number">6</span>, <span class="number">0</span>] [<span class="number">256</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><256x100xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_0 = memref.subview %<span class="number">2</span>[%<span class="number">6</span>, <span class="number">0</span>] [<span class="number">256</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><256x100xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span 
class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_1 = memref.subview %<span class="number">4</span>[%<span class="number">6</span>] [<span class="number">256</span>] [<span class="number">1</span>] : memref<<span class="number">100000</span>xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><256xf32, strided<[1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> scf.foreach_thread (%arg0) <span class="built_in">in</span> (%c64) {</span><br><span class="line"> %<span class="number">9</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 * <span class="number">4</span>)>(%arg0)</span><br><span class="line"> %subview_2 = memref.subview %subview_1[%<span class="number">9</span>] [<span class="number">4</span>] [<span class="number">1</span>] : memref<<span class="number">256</span>xf32, strided<[<span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><4xf32, strided<[1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %<span class="number">10</span> = vector.transfer_read %subview_1[%<span class="number">9</span>], %cst {in_bounds = [<span class="literal">true</span>]} : memref<<span class="number">256</span>xf32, strided<[<span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>, vector<span class="string"><4xf32></span></span></span><br><span class="line"> %<span class="number">11</span> = scf.<span class="keyword">for</span> %arg1 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg2 = %<span class="number">10</span>) -> (vector<<span class="number">4</span>xf32>) {</span><br><span class="line"> %<span 
class="number">12</span> = vector.transfer_read %subview[%<span class="number">9</span>, %arg1], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : memref<<span class="number">256</span>x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>, vector<span class="string"><4x4xf32></span></span></span><br><span class="line"> %<span class="number">13</span> = vector.transfer_read %subview_0[%<span class="number">9</span>, %arg1], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : memref<<span class="number">256</span>x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>, vector<span class="string"><4x4xf32></span></span></span><br><span class="line"> %<span class="number">14</span> = arith.addf %<span class="number">12</span>, %<span class="number">13</span> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">15</span> = vector.multi_reduction <add>, %<span class="number">14</span>, %arg2 [<span class="number">1</span>] : vector<<span class="number">4</span>x4xf32> to vector<<span class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">15</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> }</span><br><span class="line"> vector.transfer_write %<span class="number">11</span>, %subview_2[%c0] {in_bounds = [<span class="literal">true</span>]} : vector<<span class="number">4</span>xf32>, memref<<span class="number">4</span>xf32, strided<[<span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> } {mapping = [<span 
class="meta">#gpu.thread<span class="string"><x></span>]}</span></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %subview = memref.subview %<span class="number">0</span>[%<span class="number">6</span>, <span class="number">0</span>] [%<span class="number">7</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?x100xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_0 = memref.subview %<span class="number">2</span>[%<span class="number">6</span>, <span class="number">0</span>] [%<span class="number">7</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?x100xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_1 = memref.subview %<span class="number">4</span>[%<span class="number">6</span>] [%<span class="number">7</span>] [<span class="number">1</span>] : memref<<span class="number">100000</span>xf32, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?xf32, strided<[1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> scf.foreach_thread (%arg0) <span class="built_in">in</span> (%c64) {</span><br><span class="line"> %<span class="number">9</span> = affine.min <span class="built_in">affine_map</span><(d0)[s0] -> (-(d0 * (s0 ceildiv <span 
class="number">64</span>)) + s0, s0 ceildiv <span class="number">64</span>)>(%arg0)[%<span class="number">7</span>]</span><br><span class="line"> %<span class="number">10</span> = affine.max <span class="built_in">affine_map</span><(d0) -> (<span class="number">0</span>, d0)>(%<span class="number">9</span>)</span><br><span class="line"> %<span class="number">11</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 * (s0 ceildiv <span class="number">64</span>))>(%arg0)[%<span class="number">7</span>]</span><br><span class="line"> %subview_2 = memref.subview %subview[%<span class="number">11</span>, <span class="number">0</span>] [%<span class="number">10</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?x100xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_3 = memref.subview %subview_0[%<span class="number">11</span>, <span class="number">0</span>] [%<span class="number">10</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?x100xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_4 = memref.subview %subview_1[%<span class="number">11</span>] [%<span class="number">10</span>] [<span class="number">1</span>] : memref<?xf32, strided<[<span class="number">1</span>], offset: ?>, <span 
class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?xf32, strided<[1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> scf.<span class="keyword">for</span> %arg1 = %c0 to %c100 step %c4 {</span><br><span class="line"> %subview_5 = memref.subview %subview_2[<span class="number">0</span>, %arg1] [%<span class="number">10</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?x4xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> %subview_6 = memref.subview %subview_3[<span class="number">0</span>, %arg1] [%<span class="number">10</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>> to memref<span class="string"><?x4xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>></span></span><br><span class="line"> linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%subview_5, %subview_6 : memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>, <span 
class="meta">#hal.descriptor_type<span class="string"><storage_buffer></span>>, memref<span class="string"><?x4xf32, strided<[100, 1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>>) outs(%subview_4 : memref<span class="string"><?xf32, strided<[1], offset: ?></span>, #hal.descriptor_type<span class="string"><storage_buffer></span>>) attrs = {__internal_linalg_transform__ = <span class="string">"workgroup_k_tiled"</span>, lowering_config = #iree_codegen.lowering_config<span class="string"><tile_sizes = [[256, 4]]></span>} {</span></span><br><span class="line">        ^<span class="built_in">bb0</span>(%in: f32, %in_7: f32, %out: f32):</span><br><span class="line">          %<span class="number">12</span> = arith.addf %in, %in_7 : f32</span><br><span class="line">          %<span class="number">13</span> = arith.addf %<span class="number">12</span>, %out : f32</span><br><span class="line">          linalg.yield %<span class="number">13</span> : f32</span><br><span class="line">        }</span><br><span class="line">      }</span><br><span class="line">    } {mapping = [<span class="meta">#gpu.thread<span class="string"><x></span>]}</span></span><br><span class="line">  }</span><br><span class="line">  <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createLLVMGPUDistribute</p><p>Distributes the work to individual threads: the source func is rewritten from the per-thread-block computation into a per-thread computation, i.e. scf.foreach_thread is replaced with gpu.thread_id (x, y, z).</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %c100 = arith.constant <span class="number">100</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c64 = arith.constant 
<span class="number">64</span> : index</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">0</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">1</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">1</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">2</span>, <span 
class="number">64</span> : memref<<span class="number">100000</span>xf32></span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = affine.apply <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">4</span> = affine.min <span class="built_in">affine_map</span><()[s0] -> (s0 * <span class="number">-256</span> + <span class="number">100000</span>, <span class="number">256</span>)>()[%workgroup_id_x]</span><br><span class="line"> %<span class="number">5</span> = arith.cmpi eq, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">5</span> {</span><br><span class="line"> %subview = memref.subview %<span class="number">0</span>[%<span class="number">3</span>, <span class="number">0</span>] [<span class="number">256</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32> to memref<<span class="number">256</span>x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_0 = memref.subview %<span class="number">1</span>[%<span class="number">3</span>, <span class="number">0</span>] [<span class="number">256</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32> to memref<<span class="number">256</span>x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_1 = memref.subview %<span class="number">2</span>[%<span class="number">3</span>] [<span class="number">256</span>] [<span class="number">1</span>] : memref<<span 
class="number">100000</span>xf32> to memref<<span class="number">256</span>xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %<span class="number">6</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">7</span> = gpu.thread_id y</span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id z</span><br><span class="line"> %<span class="number">9</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 * <span class="number">4</span>)>(%<span class="number">6</span>)</span><br><span class="line"> %subview_2 = memref.subview %subview_1[%<span class="number">9</span>] [<span class="number">4</span>] [<span class="number">1</span>] : memref<<span class="number">256</span>xf32, strided<[<span class="number">1</span>], offset: ?>> to memref<<span class="number">4</span>xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> %<span class="number">10</span> = vector.transfer_read %subview_1[%<span class="number">9</span>], %cst {in_bounds = [<span class="literal">true</span>]} : memref<<span class="number">256</span>xf32, strided<[<span class="number">1</span>], offset: ?>>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">11</span> = scf.<span class="keyword">for</span> %arg0 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg1 = %<span class="number">10</span>) -> (vector<<span class="number">4</span>xf32>) {</span><br><span class="line"> %<span class="number">12</span> = vector.transfer_read %subview[%<span class="number">9</span>, %arg0], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : memref<<span class="number">256</span>x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>>, vector<<span 
class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">13</span> = vector.transfer_read %subview_0[%<span class="number">9</span>, %arg0], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : memref<<span class="number">256</span>x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>>, vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">14</span> = arith.addf %<span class="number">12</span>, %<span class="number">13</span> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">15</span> = vector.multi_reduction <add>, %<span class="number">14</span>, %arg1 [<span class="number">1</span>] : vector<<span class="number">4</span>x4xf32> to vector<<span class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">15</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> }</span><br><span class="line"> vector.transfer_write %<span class="number">11</span>, %subview_2[%c0] {in_bounds = [<span class="literal">true</span>]} : vector<<span class="number">4</span>xf32>, memref<<span class="number">4</span>xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %subview = memref.subview %<span class="number">0</span>[%<span class="number">3</span>, <span class="number">0</span>] [%<span class="number">4</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32> to memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_0 = memref.subview %<span class="number">1</span>[%<span class="number">3</span>, <span class="number">0</span>] [%<span 
class="number">4</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32> to memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_1 = memref.subview %<span class="number">2</span>[%<span class="number">3</span>] [%<span class="number">4</span>] [<span class="number">1</span>] : memref<<span class="number">100000</span>xf32> to memref<?xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %<span class="number">6</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">7</span> = gpu.thread_id y</span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id z</span><br><span class="line"> %<span class="number">9</span> = affine.min <span class="built_in">affine_map</span><(d0)[s0] -> (-(d0 * (s0 ceildiv <span class="number">64</span>)) + s0, s0 ceildiv <span class="number">64</span>)>(%<span class="number">6</span>)[%<span class="number">4</span>]</span><br><span class="line"> %<span class="number">10</span> = affine.max <span class="built_in">affine_map</span><(d0) -> (<span class="number">0</span>, d0)>(%<span class="number">9</span>)</span><br><span class="line"> %<span class="number">11</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 * (s0 ceildiv <span class="number">64</span>))>(%<span class="number">6</span>)[%<span class="number">4</span>]</span><br><span class="line"> %subview_2 = memref.subview %subview[%<span class="number">11</span>, <span class="number">0</span>] [%<span class="number">10</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], 
offset: ?>> to memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_3 = memref.subview %subview_0[%<span class="number">11</span>, <span class="number">0</span>] [%<span class="number">10</span>, <span class="number">100</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>> to memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_4 = memref.subview %subview_1[%<span class="number">11</span>] [%<span class="number">10</span>] [<span class="number">1</span>] : memref<?xf32, strided<[<span class="number">1</span>], offset: ?>> to memref<?xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> scf.<span class="keyword">for</span> %arg0 = %c0 to %c100 step %c4 {</span><br><span class="line"> %subview_5 = memref.subview %subview_2[<span class="number">0</span>, %arg0] [%<span class="number">10</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>> to memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_6 = memref.subview %subview_3[<span class="number">0</span>, %arg0] [%<span class="number">10</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<?x100xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>> to memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> linalg.generic {indexing_maps = [<span 
class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%subview_5, %subview_6 : memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>>, memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>>) <span class="built_in">outs</span>(%subview_4 : memref<?xf32, strided<[<span class="number">1</span>], offset: ?>>) attrs = {__internal_linalg_transform__ = <span class="string">"workgroup_k_tiled"</span>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[<span class="number">256</span>, <span class="number">4</span>]]>} {</span><br><span class="line">      ^<span class="built_in">bb0</span>(%in: f32, %in_7: f32, %out: f32):</span><br><span class="line">        %<span class="number">12</span> = arith.addf %in, %in_7 : f32</span><br><span class="line">        %<span class="number">13</span> = arith.addf %<span class="number">12</span>, %out : f32</span><br><span class="line">        linalg.yield %<span class="number">13</span> : f32</span><br><span class="line">      }</span><br><span class="line">    }</span><br><span class="line">  }</span><br><span class="line">  <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createLoopInvariantCodeMotionPass</p></li><li><p>memref::createFoldMemRefAliasOpsPass</p></li><li><p>createOptimizeVectorTransferPass</p></li></ul></li><li><p>GPUMatmulSimtPassPipeline</p></li><li><p>GPUMatmulTensorCorePassPipeline</p></li><li><p>GPUTransposePassPipeline</p></li><li><p>GPUWarpReductionPassPipeline</p></li><li><p>GPUTransformDialectPasses</p></li></ul></li></ul></li><li><p>addLowerToLLVMGPUPasses</p><p>Continues lowering the device code to the affine and gpu dialects, and finally converts it to NVVM
IR or ROCDL IR.</p><ul><li><p>IREE::LinalgExt::createLinalgExtToLoopsPass</p><p>Converts LinalgExt ops into loops.</p></li><li><p>createMemrefCopyToLinalgPass</p><p>Converts <code>memref.copy</code> into a linalg generic op.</p></li><li><p>createConvertLinalgToLoopsPass</p><p>Converts linalg generic ops into loops.</p></li><li><p>createPadDynamicAlloc</p><p>Allocates dynamically sized memory by padding. For example, when the allocation size depends on a dim, say <code>%dim = affine_max(0, %src)</code>, the buffer is allocated with the maximum possible size, i.e. as if <code>%dim = %src</code>.</p></li><li><p>createLowerAffinePass</p><p>Lowers affine ops (such as <code>affine.for</code>, <code>affine.if</code> and <code>affine.apply</code>) into the lower-level arith, memref and scf ops. The complete source func above is converted into the following code:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span
class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %c<span class="number">-1</span> = arith.constant <span class="number">-1</span> : index</span><br><span class="line"> %c64 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> %c100000 = arith.constant <span class="number">100000</span> : index</span><br><span class="line"> %c<span class="number">-256</span> = arith.constant <span class="number">-256</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %c100 = arith.constant <span class="number">100</span> : index</span><br><span class="line"> %c4 = arith.constant <span 
class="number">4</span> : index</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">0</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">1</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">1</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">2</span>, <span 
class="number">64</span> : memref<<span class="number">100000</span>xf32></span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = arith.muli %workgroup_id_x, %c<span class="number">-256</span> : index</span><br><span class="line"> %<span class="number">4</span> = arith.addi %<span class="number">3</span>, %c100000 : index</span><br><span class="line"> %<span class="number">5</span> = arith.cmpi slt, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> %<span class="number">6</span> = arith.select %<span class="number">5</span>, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> %<span class="number">7</span> = arith.cmpi eq, %<span class="number">6</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">7</span> {</span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">9</span> = arith.muli %<span class="number">8</span>, %c4 : index</span><br><span class="line"> %<span class="number">10</span> = arith.muli %workgroup_id_x, %c256 : index</span><br><span class="line"> %<span class="number">11</span> = arith.addi %<span class="number">9</span>, %<span class="number">10</span> : index</span><br><span class="line"> %<span class="number">12</span> = vector.transfer_read %<span class="number">2</span>[%<span class="number">11</span>], %cst {in_bounds = [<span class="literal">true</span>]} : memref<<span class="number">100000</span>xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">13</span> = scf.<span class="keyword">for</span> %arg0 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg1 = %<span class="number">12</span>) -> (vector<<span class="number">4</span>xf32>) {</span><br><span class="line"> %<span 
class="number">14</span> = vector.transfer_read %<span class="number">0</span>[%<span class="number">11</span>, %arg0], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">15</span> = vector.transfer_read %<span class="number">1</span>[%<span class="number">11</span>, %arg0], %cst {in_bounds = [<span class="literal">true</span>, <span class="literal">true</span>]} : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">16</span> = arith.addf %<span class="number">14</span>, %<span class="number">15</span> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">17</span> = vector.multi_reduction <add>, %<span class="number">16</span>, %arg1 [<span class="number">1</span>] : vector<<span class="number">4</span>x4xf32> to vector<<span class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">17</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> }</span><br><span class="line"> vector.transfer_write %<span class="number">13</span>, %<span class="number">2</span>[%<span class="number">11</span>] {in_bounds = [<span class="literal">true</span>]} : vector<<span class="number">4</span>xf32>, memref<<span class="number">100000</span>xf32></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">9</span> = arith.cmpi sle, %<span class="number">6</span>, %c0 : index</span><br><span class="line"> %<span class="number">10</span> = arith.subi %c0, %<span class="number">6</span> : index</span><br><span class="line"> %<span class="number">11</span> = arith.subi 
%<span class="number">6</span>, %c1 : index</span><br><span class="line"> %<span class="number">12</span> = arith.select %<span class="number">9</span>, %<span class="number">10</span>, %<span class="number">11</span> : index</span><br><span class="line"> %<span class="number">13</span> = arith.divsi %<span class="number">12</span>, %c64 : index</span><br><span class="line"> %<span class="number">14</span> = arith.subi %c0, %<span class="number">13</span> : index</span><br><span class="line"> %<span class="number">15</span> = arith.addi %<span class="number">13</span>, %c1 : index</span><br><span class="line"> %<span class="number">16</span> = arith.select %<span class="number">9</span>, %<span class="number">14</span>, %<span class="number">15</span> : index</span><br><span class="line"> %<span class="number">17</span> = arith.muli %<span class="number">8</span>, %<span class="number">16</span> : index</span><br><span class="line"> %<span class="number">18</span> = arith.muli %<span class="number">17</span>, %c<span class="number">-1</span> : index</span><br><span class="line"> %<span class="number">19</span> = arith.addi %<span class="number">18</span>, %<span class="number">6</span> : index</span><br><span class="line"> %<span class="number">20</span> = arith.cmpi slt, %<span class="number">19</span>, %<span class="number">16</span> : index</span><br><span class="line"> %<span class="number">21</span> = arith.select %<span class="number">20</span>, %<span class="number">19</span>, %<span class="number">16</span> : index</span><br><span class="line"> %<span class="number">22</span> = arith.cmpi slt, %<span class="number">21</span>, %c0 : index</span><br><span class="line"> %<span class="number">23</span> = arith.select %<span class="number">22</span>, %c0, %<span class="number">21</span> : index</span><br><span class="line"> %<span class="number">24</span> = arith.muli %workgroup_id_x, %c256 : index</span><br><span class="line"> %<span class="number">25</span> = 
arith.addi %<span class="number">17</span>, %<span class="number">24</span> : index</span><br><span class="line"> %subview = memref.subview %<span class="number">2</span>[%<span class="number">25</span>] [%<span class="number">23</span>] [<span class="number">1</span>] : memref<<span class="number">100000</span>xf32> to memref<?xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> scf.<span class="keyword">for</span> %arg0 = %c0 to %c100 step %c4 {</span><br><span class="line"> %subview_0 = memref.subview %<span class="number">0</span>[%<span class="number">25</span>, %arg0] [%<span class="number">23</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32> to memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %subview_1 = memref.subview %<span class="number">1</span>[%<span class="number">25</span>, %arg0] [%<span class="number">23</span>, <span class="number">4</span>] [<span class="number">1</span>, <span class="number">1</span>] : memref<<span class="number">100000</span>x100xf32> to memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> scf.<span class="keyword">for</span> %arg1 = %c0 to %<span class="number">23</span> step %c1 {</span><br><span class="line"> scf.<span class="keyword">for</span> %arg2 = %c0 to %c4 step %c1 {</span><br><span class="line"> %<span class="number">26</span> = memref.load %subview_0[%arg1, %arg2] : memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %<span class="number">27</span> = memref.load %subview_1[%arg1, %arg2] : memref<?x4xf32, strided<[<span class="number">100</span>, <span class="number">1</span>], offset: ?>></span><br><span class="line"> %<span 
class="number">28</span> = memref.load %subview[%arg1] : memref<?xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> %<span class="number">29</span> = arith.addf %<span class="number">26</span>, %<span class="number">27</span> : f32</span><br><span class="line"> %<span class="number">30</span> = arith.addf %<span class="number">29</span>, %<span class="number">28</span> : f32</span><br><span class="line"> memref.store %<span class="number">30</span>, %subview[%arg1] : memref<?xf32, strided<[<span class="number">1</span>], offset: ?>></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>arith::createConstantBufferizePass</p></li><li><p>createFoldTensorExtractOpPass</p></li><li><p>createLLVMGPUVectorLoweringPass</p><p>Unrolls multi-dimensional vector ops into one-dimensional vector ops. The complete source func above is transformed into the following code:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span 
class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span 
class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %cst = arith.constant dense<<span class="number">0.000000e+00</span>> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %c<span class="number">-1</span> = arith.constant <span class="number">-1</span> : index</span><br><span class="line"> %c64 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> %c100000 = arith.constant <span class="number">100000</span> : index</span><br><span class="line"> %c<span class="number">-256</span> = arith.constant <span class="number">-256</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c100 = arith.constant <span class="number">100</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) 
<span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">0</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">1</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">1</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">2</span>, <span class="number">64</span> : memref<<span class="number">100000</span>xf32></span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = arith.muli %workgroup_id_x, %c<span class="number">-256</span> : index</span><br><span class="line"> %<span class="number">4</span> = arith.addi %<span class="number">3</span>, %c100000 : index</span><br><span class="line"> %<span class="number">5</span> = arith.cmpi slt, %<span class="number">4</span>, 
%c256 : index</span><br><span class="line"> %<span class="number">6</span> = arith.select %<span class="number">5</span>, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> %<span class="number">7</span> = arith.cmpi eq, %<span class="number">6</span>, %c256 : index</span><br><span class="line"> scf.<span class="keyword">if</span> %<span class="number">7</span> {</span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">9</span> = arith.muli %<span class="number">8</span>, %c4 : index</span><br><span class="line"> %<span class="number">10</span> = arith.muli %workgroup_id_x, %c256 : index</span><br><span class="line"> %<span class="number">11</span> = arith.addi %<span class="number">9</span>, %<span class="number">10</span> : index</span><br><span class="line"> %<span class="number">12</span> = vector.load %<span class="number">2</span>[%<span class="number">11</span>] : memref<<span class="number">100000</span>xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">13</span> = scf.<span class="keyword">for</span> %arg0 = %c0 to %c100 step %c4 <span class="built_in">iter_args</span>(%arg1 = %<span class="number">12</span>) -> (vector<<span class="number">4</span>xf32>) {</span><br><span class="line"> %<span class="number">14</span> = vector.load %<span class="number">0</span>[%<span class="number">11</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">15</span> = vector.insert %<span class="number">14</span>, %cst [<span class="number">0</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">16</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 + <span class="number">1</span>)>(%<span 
class="number">11</span>)</span><br><span class="line"> %<span class="number">17</span> = vector.load %<span class="number">0</span>[%<span class="number">16</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">18</span> = vector.insert %<span class="number">17</span>, %<span class="number">15</span> [<span class="number">1</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">19</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 + <span class="number">2</span>)>(%<span class="number">11</span>)</span><br><span class="line"> %<span class="number">20</span> = vector.load %<span class="number">0</span>[%<span class="number">19</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">21</span> = vector.insert %<span class="number">20</span>, %<span class="number">18</span> [<span class="number">2</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">22</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 + <span class="number">3</span>)>(%<span class="number">11</span>)</span><br><span class="line"> %<span class="number">23</span> = vector.load %<span class="number">0</span>[%<span class="number">22</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">24</span> = vector.insert %<span class="number">23</span>, %<span class="number">21</span> [<span class="number">3</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span 
class="number">25</span> = vector.load %<span class="number">1</span>[%<span class="number">11</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">26</span> = vector.insert %<span class="number">25</span>, %cst [<span class="number">0</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">27</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 + <span class="number">1</span>)>(%<span class="number">11</span>)</span><br><span class="line"> %<span class="number">28</span> = vector.load %<span class="number">1</span>[%<span class="number">27</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">29</span> = vector.insert %<span class="number">28</span>, %<span class="number">26</span> [<span class="number">1</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">30</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 + <span class="number">2</span>)>(%<span class="number">11</span>)</span><br><span class="line"> %<span class="number">31</span> = vector.load %<span class="number">1</span>[%<span class="number">30</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">32</span> = vector.insert %<span class="number">31</span>, %<span class="number">29</span> [<span class="number">2</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">33</span> = affine.apply <span class="built_in">affine_map</span><(d0) -> (d0 + 
<span class="number">3</span>)>(%<span class="number">11</span>)</span><br><span class="line"> %<span class="number">34</span> = vector.load %<span class="number">1</span>[%<span class="number">33</span>, %arg0] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">35</span> = vector.insert %<span class="number">34</span>, %<span class="number">32</span> [<span class="number">3</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">36</span> = arith.addf %<span class="number">24</span>, %<span class="number">35</span> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">37</span> = vector.transpose %<span class="number">36</span>, [<span class="number">1</span>, <span class="number">0</span>] : vector<<span class="number">4</span>x4xf32> to vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">38</span> = vector.extract %<span class="number">37</span>[<span class="number">0</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">39</span> = arith.addf %<span class="number">38</span>, %arg1 : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">40</span> = vector.extract %<span class="number">37</span>[<span class="number">1</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">41</span> = arith.addf %<span class="number">40</span>, %<span class="number">39</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">42</span> = vector.extract %<span class="number">37</span>[<span class="number">2</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">43</span> 
= arith.addf %<span class="number">42</span>, %<span class="number">41</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">44</span> = vector.extract %<span class="number">37</span>[<span class="number">3</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">45</span> = arith.addf %<span class="number">44</span>, %<span class="number">43</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> scf.yield %<span class="number">45</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> }</span><br><span class="line"> vector.store %<span class="number">13</span>, %<span class="number">2</span>[%<span class="number">11</span>] : memref<<span class="number">100000</span>xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">9</span> = arith.cmpi sle, %<span class="number">6</span>, %c0 : index</span><br><span class="line"> %<span class="number">10</span> = arith.subi %c0, %<span class="number">6</span> : index</span><br><span class="line"> %<span class="number">11</span> = arith.subi %<span class="number">6</span>, %c1 : index</span><br><span class="line"> %<span class="number">12</span> = arith.select %<span class="number">9</span>, %<span class="number">10</span>, %<span class="number">11</span> : index</span><br><span class="line"> %<span class="number">13</span> = arith.divsi %<span class="number">12</span>, %c64 : index</span><br><span class="line"> %<span class="number">14</span> = arith.subi %c0, %<span class="number">13</span> : index</span><br><span class="line"> %<span class="number">15</span> = arith.addi %<span class="number">13</span>, %c1 : index</span><br><span class="line"> %<span class="number">16</span> = 
arith.select %<span class="number">9</span>, %<span class="number">14</span>, %<span class="number">15</span> : index</span><br><span class="line"> %<span class="number">17</span> = arith.muli %<span class="number">8</span>, %<span class="number">16</span> : index</span><br><span class="line"> %<span class="number">18</span> = arith.muli %<span class="number">17</span>, %c<span class="number">-1</span> : index</span><br><span class="line"> %<span class="number">19</span> = arith.addi %<span class="number">18</span>, %<span class="number">6</span> : index</span><br><span class="line"> %<span class="number">20</span> = arith.cmpi slt, %<span class="number">19</span>, %<span class="number">16</span> : index</span><br><span class="line"> %<span class="number">21</span> = arith.select %<span class="number">20</span>, %<span class="number">19</span>, %<span class="number">16</span> : index</span><br><span class="line"> %<span class="number">22</span> = arith.cmpi slt, %<span class="number">21</span>, %c0 : index</span><br><span class="line"> %<span class="number">23</span> = arith.select %<span class="number">22</span>, %c0, %<span class="number">21</span> : index</span><br><span class="line"> %<span class="number">24</span> = arith.muli %workgroup_id_x, %c256 : index</span><br><span class="line"> %<span class="number">25</span> = arith.addi %<span class="number">17</span>, %<span class="number">24</span> : index</span><br><span class="line"> scf.<span class="keyword">for</span> %arg0 = %c0 to %c100 step %c4 {</span><br><span class="line"> scf.<span class="keyword">for</span> %arg1 = %c0 to %<span class="number">23</span> step %c1 {</span><br><span class="line"> scf.<span class="keyword">for</span> %arg2 = %c0 to %c4 step %c1 {</span><br><span class="line"> %<span class="number">26</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 + s0)>(%arg1)[%<span class="number">25</span>]</span><br><span class="line"> %<span class="number">27</span> = 
affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 + s0)>(%arg2)[%arg0]</span><br><span class="line"> %<span class="number">28</span> = memref.load %<span class="number">0</span>[%<span class="number">26</span>, %<span class="number">27</span>] : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">29</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 + s0)>(%arg1)[%<span class="number">25</span>]</span><br><span class="line"> %<span class="number">30</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 + s0)>(%arg2)[%arg0]</span><br><span class="line"> %<span class="number">31</span> = memref.load %<span class="number">1</span>[%<span class="number">29</span>, %<span class="number">30</span>] : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">32</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 + s0)>(%arg1)[%<span class="number">25</span>]</span><br><span class="line"> %<span class="number">33</span> = memref.load %<span class="number">2</span>[%<span class="number">32</span>] : memref<<span class="number">100000</span>xf32></span><br><span class="line"> %<span class="number">34</span> = arith.addf %<span class="number">28</span>, %<span class="number">31</span> : f32</span><br><span class="line"> %<span class="number">35</span> = arith.addf %<span class="number">34</span>, %<span class="number">33</span> : f32</span><br><span class="line"> %<span class="number">36</span> = affine.apply <span class="built_in">affine_map</span><(d0)[s0] -> (d0 + s0)>(%arg1)[%<span class="number">25</span>]</span><br><span class="line"> memref.store %<span class="number">35</span>, %<span class="number">2</span>[%<span class="number">36</span>] : memref<<span class="number">100000</span>xf32></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span 
class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createConvertSCFToCFPass</p><p>Converts structured control flow into CFG-style control flow. The complete source func above is transformed into the following code:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span 
class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span 
class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>() {</span><br><span class="line"> %cst = arith.constant dense<<span class="number">0.000000e+00</span>> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %c<span class="number">-1</span> = arith.constant <span class="number">-1</span> : index</span><br><span class="line"> %c64 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> %c100000 = arith.constant <span class="number">100000</span> : index</span><br><span class="line"> %c<span class="number">-256</span> = arith.constant <span class="number">-256</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c100 = arith.constant <span class="number">100</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c256 = arith.constant <span class="number">256</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">0</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">0</span>, <span class="number">64</span> : memref<<span 
class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">1</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">1</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">1</span>, <span class="number">64</span> : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">2</span> = hal.interface.binding.subspan <span class="built_in">set</span>(<span class="number">0</span>) <span class="built_in">binding</span>(<span class="number">2</span>) <span class="built_in">type</span>(storage_buffer) <span class="built_in">offset</span>(%c0) <span class="built_in">alignment</span>(<span class="number">64</span>) : memref<<span class="number">100000</span>xf32></span><br><span class="line"> memref.assume_alignment %<span class="number">2</span>, <span class="number">64</span> : memref<<span class="number">100000</span>xf32></span><br><span class="line"> %workgroup_id_x = hal.interface.workgroup.id[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = arith.muli %workgroup_id_x, %c<span class="number">-256</span> : index</span><br><span class="line"> %<span class="number">4</span> = arith.addi %<span class="number">3</span>, %c100000 : index</span><br><span class="line"> %<span class="number">5</span> = arith.cmpi slt, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> %<span class="number">6</span> = arith.select %<span class="number">5</span>, %<span class="number">4</span>, %c256 : index</span><br><span class="line"> %<span class="number">7</span> = arith.cmpi eq, %<span 
class="number">6</span>, %c256 : index</span><br><span class="line"> cf.cond_br %<span class="number">7</span>, ^bb1, ^bb5</span><br><span class="line"> ^bb1: <span class="comment">// pred: ^bb0</span></span><br><span class="line"> %<span class="number">8</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">9</span> = arith.muli %<span class="number">8</span>, %c4 : index</span><br><span class="line"> %<span class="number">10</span> = arith.muli %workgroup_id_x, %c256 : index</span><br><span class="line"> %<span class="number">11</span> = arith.addi %<span class="number">9</span>, %<span class="number">10</span> : index</span><br><span class="line"> %<span class="number">12</span> = vector.load %<span class="number">2</span>[%<span class="number">11</span>] : memref<<span class="number">100000</span>xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> cf.br ^<span class="built_in">bb2</span>(%c0, %<span class="number">12</span> : index, vector<<span class="number">4</span>xf32>)</span><br><span class="line"> ^<span class="built_in">bb2</span>(%<span class="number">13</span>: index, %<span class="number">14</span>: vector<<span class="number">4</span>xf32>): <span class="comment">// 2 preds: ^bb1, ^bb3</span></span><br><span class="line"> %<span class="number">15</span> = arith.cmpi slt, %<span class="number">13</span>, %c100 : index</span><br><span class="line"> cf.cond_br %<span class="number">15</span>, ^bb3, ^bb4</span><br><span class="line"> ^bb3: <span class="comment">// pred: ^bb2</span></span><br><span class="line"> %<span class="number">16</span> = vector.load %<span class="number">0</span>[%<span class="number">11</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">17</span> = vector.insert %<span class="number">16</span>, %cst [<span class="number">0</span>] : 
vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %c1_0 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %<span class="number">18</span> = arith.addi %<span class="number">11</span>, %c1_0 : index</span><br><span class="line"> %<span class="number">19</span> = vector.load %<span class="number">0</span>[%<span class="number">18</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">20</span> = vector.insert %<span class="number">19</span>, %<span class="number">17</span> [<span class="number">1</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> %<span class="number">21</span> = arith.addi %<span class="number">11</span>, %c2 : index</span><br><span class="line"> %<span class="number">22</span> = vector.load %<span class="number">0</span>[%<span class="number">21</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">23</span> = vector.insert %<span class="number">22</span>, %<span class="number">20</span> [<span class="number">2</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %c3 = arith.constant <span class="number">3</span> : index</span><br><span class="line"> %<span class="number">24</span> = arith.addi %<span class="number">11</span>, %c3 : index</span><br><span class="line"> %<span class="number">25</span> = vector.load %<span class="number">0</span>[%<span class="number">24</span>, %<span class="number">13</span>] : memref<<span 
class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">26</span> = vector.insert %<span class="number">25</span>, %<span class="number">23</span> [<span class="number">3</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">27</span> = vector.load %<span class="number">1</span>[%<span class="number">11</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">28</span> = vector.insert %<span class="number">27</span>, %cst [<span class="number">0</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">29</span> = vector.load %<span class="number">1</span>[%<span class="number">18</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">30</span> = vector.insert %<span class="number">29</span>, %<span class="number">28</span> [<span class="number">1</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">31</span> = vector.load %<span class="number">1</span>[%<span class="number">21</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">32</span> = vector.insert %<span class="number">31</span>, %<span class="number">30</span> [<span class="number">2</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">33</span> = 
vector.load %<span class="number">1</span>[%<span class="number">24</span>, %<span class="number">13</span>] : memref<<span class="number">100000</span>x100xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">34</span> = vector.insert %<span class="number">33</span>, %<span class="number">32</span> [<span class="number">3</span>] : vector<<span class="number">4</span>xf32> into vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">35</span> = arith.addf %<span class="number">26</span>, %<span class="number">34</span> : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">36</span> = vector.transpose %<span class="number">35</span>, [<span class="number">1</span>, <span class="number">0</span>] : vector<<span class="number">4</span>x4xf32> to vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">37</span> = vector.extract %<span class="number">36</span>[<span class="number">0</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">38</span> = arith.addf %<span class="number">37</span>, %<span class="number">14</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">39</span> = vector.extract %<span class="number">36</span>[<span class="number">1</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">40</span> = arith.addf %<span class="number">39</span>, %<span class="number">38</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">41</span> = vector.extract %<span class="number">36</span>[<span class="number">2</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">42</span> = arith.addf %<span class="number">41</span>, %<span 
class="number">40</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">43</span> = vector.extract %<span class="number">36</span>[<span class="number">3</span>] : vector<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">44</span> = arith.addf %<span class="number">43</span>, %<span class="number">42</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">45</span> = arith.addi %<span class="number">13</span>, %c4 : index</span><br><span class="line"> cf.br ^<span class="built_in">bb2</span>(%<span class="number">45</span>, %<span class="number">44</span> : index, vector<<span class="number">4</span>xf32>)</span><br><span class="line"> ^bb4: <span class="comment">// pred: ^bb2</span></span><br><span class="line"> vector.store %<span class="number">14</span>, %<span class="number">2</span>[%<span class="number">11</span>] : memref<<span class="number">100000</span>xf32>, vector<<span class="number">4</span>xf32></span><br><span class="line"> cf.br ^bb12</span><br><span class="line"> ^bb5: <span class="comment">// pred: ^bb0</span></span><br><span class="line"> %<span class="number">46</span> = gpu.thread_id x</span><br><span class="line"> %<span class="number">47</span> = arith.cmpi sle, %<span class="number">6</span>, %c0 : index</span><br><span class="line"> %<span class="number">48</span> = arith.subi %c0, %<span class="number">6</span> : index</span><br><span class="line"> %<span class="number">49</span> = arith.subi %<span class="number">6</span>, %c1 : index</span><br><span class="line"> %<span class="number">50</span> = arith.select %<span class="number">47</span>, %<span class="number">48</span>, %<span class="number">49</span> : index</span><br><span class="line"> %<span class="number">51</span> = arith.divsi %<span class="number">50</span>, %c64 : index</span><br><span class="line"> %<span class="number">52</span> = 
arith.subi %c0, %<span class="number">51</span> : index</span><br><span class="line"> %<span class="number">53</span> = arith.addi %<span class="number">51</span>, %c1 : index</span><br><span class="line"> %<span class="number">54</span> = arith.select %<span class="number">47</span>, %<span class="number">52</span>, %<span class="number">53</span> : index</span><br><span class="line"> %<span class="number">55</span> = arith.muli %<span class="number">46</span>, %<span class="number">54</span> : index</span><br><span class="line"> %<span class="number">56</span> = arith.muli %<span class="number">55</span>, %c<span class="number">-1</span> : index</span><br><span class="line"> %<span class="number">57</span> = arith.addi %<span class="number">56</span>, %<span class="number">6</span> : index</span><br><span class="line"> %<span class="number">58</span> = arith.cmpi slt, %<span class="number">57</span>, %<span class="number">54</span> : index</span><br><span class="line"> %<span class="number">59</span> = arith.select %<span class="number">58</span>, %<span class="number">57</span>, %<span class="number">54</span> : index</span><br><span class="line"> %<span class="number">60</span> = arith.cmpi slt, %<span class="number">59</span>, %c0 : index</span><br><span class="line"> %<span class="number">61</span> = arith.select %<span class="number">60</span>, %c0, %<span class="number">59</span> : index</span><br><span class="line"> %<span class="number">62</span> = arith.muli %workgroup_id_x, %c256 : index</span><br><span class="line"> %<span class="number">63</span> = arith.addi %<span class="number">55</span>, %<span class="number">62</span> : index</span><br><span class="line"> cf.br ^<span class="built_in">bb6</span>(%c0 : index)</span><br><span class="line"> ^<span class="built_in">bb6</span>(%<span class="number">64</span>: index): <span class="comment">// 2 preds: ^bb5, ^bb11</span></span><br><span class="line"> %<span class="number">65</span> = arith.cmpi slt, 
%<span class="number">64</span>, %c100 : index</span><br><span class="line"> cf.cond_br %<span class="number">65</span>, ^<span class="built_in">bb7</span>(%c0 : index), ^bb12</span><br><span class="line"> ^<span class="built_in">bb7</span>(%<span class="number">66</span>: index): <span class="comment">// 2 preds: ^bb6, ^bb10</span></span><br><span class="line"> %<span class="number">67</span> = arith.cmpi slt, %<span class="number">66</span>, %<span class="number">61</span> : index</span><br><span class="line"> cf.cond_br %<span class="number">67</span>, ^<span class="built_in">bb8</span>(%c0 : index), ^bb11</span><br><span class="line"> ^<span class="built_in">bb8</span>(%<span class="number">68</span>: index): <span class="comment">// 2 preds: ^bb7, ^bb9</span></span><br><span class="line"> %<span class="number">69</span> = arith.cmpi slt, %<span class="number">68</span>, %c4 : index</span><br><span class="line"> cf.cond_br %<span class="number">69</span>, ^bb9, ^bb10</span><br><span class="line"> ^bb9: <span class="comment">// pred: ^bb8</span></span><br><span class="line"> %<span class="number">70</span> = arith.addi %<span class="number">63</span>, %<span class="number">66</span> : index</span><br><span class="line"> %<span class="number">71</span> = arith.addi %<span class="number">64</span>, %<span class="number">68</span> : index</span><br><span class="line"> %<span class="number">72</span> = memref.load %<span class="number">0</span>[%<span class="number">70</span>, %<span class="number">71</span>] : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">73</span> = memref.load %<span class="number">1</span>[%<span class="number">70</span>, %<span class="number">71</span>] : memref<<span class="number">100000</span>x100xf32></span><br><span class="line"> %<span class="number">74</span> = memref.load %<span class="number">2</span>[%<span class="number">70</span>] : memref<<span 
class="number">100000</span>xf32></span><br><span class="line"> %<span class="number">75</span> = arith.addf %<span class="number">72</span>, %<span class="number">73</span> : f32</span><br><span class="line"> %<span class="number">76</span> = arith.addf %<span class="number">75</span>, %<span class="number">74</span> : f32</span><br><span class="line"> memref.store %<span class="number">76</span>, %<span class="number">2</span>[%<span class="number">70</span>] : memref<<span class="number">100000</span>xf32></span><br><span class="line"> %<span class="number">77</span> = arith.addi %<span class="number">68</span>, %c1 : index</span><br><span class="line"> cf.br ^<span class="built_in">bb8</span>(%<span class="number">77</span> : index)</span><br><span class="line"> ^bb10: <span class="comment">// pred: ^bb8</span></span><br><span class="line"> %<span class="number">78</span> = arith.addi %<span class="number">66</span>, %c1 : index</span><br><span class="line"> cf.br ^<span class="built_in">bb7</span>(%<span class="number">78</span> : index)</span><br><span class="line"> ^bb11: <span class="comment">// pred: ^bb7</span></span><br><span class="line"> %<span class="number">79</span> = arith.addi %<span class="number">64</span>, %c4 : index</span><br><span class="line"> cf.br ^<span class="built_in">bb6</span>(%<span class="number">79</span> : index)</span><br><span class="line"> ^bb12: <span class="comment">// 2 preds: ^bb4, ^bb6</span></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createPolynomialApproximationPass</p></li><li><p>arith::createArithExpandOpsPass</p></li><li><p>memref::createExpandOpsPass</p></li><li><p>memref::createExpandStridedMetadataPass</p></li><li><p>createLowerAffinePass</p></li><li><p>createStripDebugInfoPass</p></li><li><p>createConvertToROCDLPass或createConvertToNVVMPass</p><p>转换到ROCDL IR或NVVM IR。上面完整的sourcefunc代码会转换成如下代码,</p><figure 
class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span 
class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br><span class="line">148</span><br><span class="line">149</span><br><span class="line">150</span><br><span class="line">151</span><br><span class="line">152</span><br><span class="line">153</span><br><span class="line">154</span><br><span class="line">155</span><br><span class="line">156</span><br><span class="line">157</span><br><span class="line">158</span><br><span class="line">159</span><br><span class="line">160</span><br><span class="line">161</span><br><span class="line">162</span><br><span class="line">163</span><br><span class="line">164</span><br><span class="line">165</span><br><span class="line">166</span><br><span class="line">167</span><br><span class="line">168</span><br><span class="line">169</span><br><span class="line">170</span><br><span class="line">171</span><br><span class="line">172</span><br><span class="line">173</span><br><span class="line">174</span><br><span class="line">175</span><br><span class="line">176</span><br><span class="line">177</span><br><span class="line">178</span><br><span class="line">179</span><br><span class="line">180</span><br><span class="line">181</span><br><span 
class="line">182</span><br><span class="line">183</span><br><span class="line">184</span><br><span class="line">185</span><br><span class="line">186</span><br><span class="line">187</span><br><span class="line">188</span><br><span class="line">189</span><br><span class="line">190</span><br><span class="line">191</span><br><span class="line">192</span><br><span class="line">193</span><br><span class="line">194</span><br><span class="line">195</span><br><span class="line">196</span><br><span class="line">197</span><br><span class="line">198</span><br><span class="line">199</span><br><span class="line">200</span><br><span class="line">201</span><br><span class="line">202</span><br><span class="line">203</span><br><span class="line">204</span><br><span class="line">205</span><br><span class="line">206</span><br><span class="line">207</span><br><span class="line">208</span><br><span class="line">209</span><br><span class="line">210</span><br><span class="line">211</span><br><span class="line">212</span><br><span class="line">213</span><br><span class="line">214</span><br><span class="line">215</span><br><span class="line">216</span><br><span class="line">217</span><br><span class="line">218</span><br><span class="line">219</span><br><span class="line">220</span><br><span class="line">221</span><br><span class="line">222</span><br><span class="line">223</span><br><span class="line">224</span><br><span class="line">225</span><br><span class="line">226</span><br><span class="line">227</span><br><span class="line">228</span><br><span class="line">229</span><br><span class="line">230</span><br><span class="line">231</span><br><span class="line">232</span><br><span class="line">233</span><br><span class="line">234</span><br><span class="line">235</span><br><span class="line">236</span><br><span class="line">237</span><br><span class="line">238</span><br><span class="line">239</span><br><span class="line">240</span><br><span class="line">241</span><br><span 
class="line">242</span><br><span class="line">243</span><br><span class="line">244</span><br><span class="line">245</span><br><span class="line">246</span><br><span class="line">247</span><br><span class="line">248</span><br><span class="line">249</span><br><span class="line">250</span><br><span class="line">251</span><br><span class="line">252</span><br><span class="line">253</span><br><span class="line">254</span><br><span class="line">255</span><br><span class="line">256</span><br><span class="line">257</span><br><span class="line">258</span><br><span class="line">259</span><br><span class="line">260</span><br><span class="line">261</span><br><span class="line">262</span><br><span class="line">263</span><br><span class="line">264</span><br><span class="line">265</span><br><span class="line">266</span><br><span class="line">267</span><br><span class="line">268</span><br><span class="line">269</span><br><span class="line">270</span><br><span class="line">271</span><br><span class="line">272</span><br><span class="line">273</span><br><span class="line">274</span><br><span class="line">275</span><br><span class="line">276</span><br><span class="line">277</span><br><span class="line">278</span><br><span class="line">279</span><br><span class="line">280</span><br><span class="line">281</span><br><span class="line">282</span><br><span class="line">283</span><br><span class="line">284</span><br><span class="line">285</span><br><span class="line">286</span><br><span class="line">287</span><br><span class="line">288</span><br><span class="line">289</span><br><span class="line">290</span><br><span class="line">291</span><br><span class="line">292</span><br><span class="line">293</span><br><span class="line">294</span><br><span class="line">295</span><br><span class="line">296</span><br><span class="line">297</span><br><span class="line">298</span><br><span class="line">299</span><br><span class="line">300</span><br><span class="line">301</span><br><span 
class="line">302</span><br><span class="line">303</span><br><span class="line">304</span><br><span class="line">305</span><br><span class="line">306</span><br><span class="line">307</span><br><span class="line">308</span><br><span class="line">309</span><br><span class="line">310</span><br><span class="line">311</span><br><span class="line">312</span><br><span class="line">313</span><br><span class="line">314</span><br><span class="line">315</span><br><span class="line">316</span><br><span class="line">317</span><br><span class="line">318</span><br><span class="line">319</span><br><span class="line">320</span><br><span class="line">321</span><br><span class="line">322</span><br><span class="line">323</span><br><span class="line">324</span><br><span class="line">325</span><br><span class="line">326</span><br><span class="line">327</span><br><span class="line">328</span><br><span class="line">329</span><br><span class="line">330</span><br><span class="line">331</span><br><span class="line">332</span><br><span class="line">333</span><br><span class="line">334</span><br><span class="line">335</span><br><span class="line">336</span><br><span class="line">337</span><br><span class="line">338</span><br><span class="line">339</span><br><span class="line">340</span><br><span class="line">341</span><br><span class="line">342</span><br><span class="line">343</span><br><span class="line">344</span><br><span class="line">345</span><br><span class="line">346</span><br><span class="line">347</span><br><span class="line">348</span><br><span class="line">349</span><br><span class="line">350</span><br><span class="line">351</span><br><span class="line">352</span><br><span class="line">353</span><br><span class="line">354</span><br><span class="line">355</span><br><span class="line">356</span><br><span class="line">357</span><br><span class="line">358</span><br><span class="line">359</span><br><span class="line">360</span><br><span class="line">361</span><br><span 
class="line">362</span><br><span class="line">363</span><br><span class="line">364</span><br><span class="line">365</span><br><span class="line">366</span><br><span class="line">367</span><br><span class="line">368</span><br><span class="line">369</span><br><span class="line">370</span><br><span class="line">371</span><br><span class="line">372</span><br><span class="line">373</span><br><span class="line">374</span><br><span class="line">375</span><br><span class="line">376</span><br><span class="line">377</span><br></pre></td><td class="code"><pre><span class="line">llvm.func @<span class="built_in">test_dispatch_0_generic_100000x100</span>(%arg0: !llvm.ptr<f32> {llvm.align = <span class="number">16</span> : i32}, %arg1: !llvm.ptr<f32> {llvm.align = <span class="number">16</span> : i32}, %arg2: !llvm.ptr<f32> {llvm.align = <span class="number">16</span> : i32}) {</span><br><span class="line"> %<span class="number">0</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : index) : i64</span><br><span class="line"> %<span class="number">1</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : index) : i64</span><br><span class="line"> %<span class="number">2</span> = llvm.mlir.<span class="built_in">constant</span>(dense<<span class="number">0.000000e+00</span>> : vector<<span class="number">4</span>x4xf32>) : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">3</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">-1</span> : index) : i64</span><br><span class="line"> %<span class="number">4</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">64</span> : index) : i64</span><br><span class="line"> %<span class="number">5</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100000</span> : index) : i64</span><br><span class="line"> 
%<span class="number">6</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">-256</span> : index) : i64</span><br><span class="line"> %<span class="number">7</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : index) : i64</span><br><span class="line"> %<span class="number">8</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">9</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">4</span> : index) : i64</span><br><span class="line"> %<span class="number">10</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">256</span> : index) : i64</span><br><span class="line"> %<span class="number">11</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">12</span> = llvm.bitcast %arg0 : !llvm.ptr<f32> to !llvm.ptr<i8></span><br><span class="line"> %<span class="number">13</span> = llvm.getelementptr %<span class="number">12</span>[%<span class="number">11</span>] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8></span><br><span class="line"> %<span class="number">14</span> = llvm.bitcast %<span class="number">13</span> : !llvm.ptr<i8> to !llvm.ptr<f32></span><br><span class="line"> %<span class="number">15</span> = llvm.mlir.undef : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">16</span> = llvm.insertvalue %<span class="number">14</span>, %<span class="number">15</span>[<span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span 
class="number">17</span> = llvm.insertvalue %<span class="number">14</span>, %<span class="number">16</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">18</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">19</span> = llvm.insertvalue %<span class="number">18</span>, %<span class="number">17</span>[<span class="number">2</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">20</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100000</span> : index) : i64</span><br><span class="line"> %<span class="number">21</span> = llvm.insertvalue %<span class="number">20</span>, %<span class="number">19</span>[<span class="number">3</span>, <span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">22</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">23</span> = llvm.insertvalue %<span class="number">22</span>, %<span class="number">21</span>[<span class="number">4</span>, <span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">24</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : 
index) : i64</span><br><span class="line"> %<span class="number">25</span> = llvm.insertvalue %<span class="number">24</span>, %<span class="number">23</span>[<span class="number">3</span>, <span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">26</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : index) : i64</span><br><span class="line"> %<span class="number">27</span> = llvm.insertvalue %<span class="number">26</span>, %<span class="number">25</span>[<span class="number">4</span>, <span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">28</span> = llvm.extractvalue %<span class="number">27</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">29</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">30</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">63</span> : index) : i64</span><br><span class="line"> %<span class="number">31</span> = llvm.ptrtoint %<span class="number">28</span> : !llvm.ptr<f32> to i64</span><br><span class="line"> %<span class="number">32</span> = llvm.<span class="keyword">and</span> %<span class="number">31</span>, %<span class="number">30</span> : i64</span><br><span class="line"> %<span class="number">33</span> = llvm.icmp <span class="string">"eq"</span> %<span class="number">32</span>, %<span 
class="number">29</span> : i64</span><br><span class="line"> <span class="string">"llvm.intr.assume"</span>(%<span class="number">33</span>) : (i1) -> ()</span><br><span class="line"> %<span class="number">34</span> = llvm.bitcast %arg1 : !llvm.ptr<f32> to !llvm.ptr<i8></span><br><span class="line"> %<span class="number">35</span> = llvm.getelementptr %<span class="number">34</span>[%<span class="number">11</span>] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8></span><br><span class="line"> %<span class="number">36</span> = llvm.bitcast %<span class="number">35</span> : !llvm.ptr<i8> to !llvm.ptr<f32></span><br><span class="line"> %<span class="number">37</span> = llvm.mlir.undef : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">38</span> = llvm.insertvalue %<span class="number">36</span>, %<span class="number">37</span>[<span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">39</span> = llvm.insertvalue %<span class="number">36</span>, %<span class="number">38</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">40</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">41</span> = llvm.insertvalue %<span class="number">40</span>, %<span class="number">39</span>[<span class="number">2</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x 
i64>)></span><br><span class="line"> %<span class="number">42</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100000</span> : index) : i64</span><br><span class="line"> %<span class="number">43</span> = llvm.insertvalue %<span class="number">42</span>, %<span class="number">41</span>[<span class="number">3</span>, <span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">44</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">45</span> = llvm.insertvalue %<span class="number">44</span>, %<span class="number">43</span>[<span class="number">4</span>, <span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">46</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">47</span> = llvm.insertvalue %<span class="number">46</span>, %<span class="number">45</span>[<span class="number">3</span>, <span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">48</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : index) : i64</span><br><span class="line"> %<span class="number">49</span> = llvm.insertvalue %<span class="number">48</span>, %<span class="number">47</span>[<span class="number">4</span>, <span class="number">1</span>] : !llvm.<span 
class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">50</span> = llvm.extractvalue %<span class="number">49</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">51</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">52</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">63</span> : index) : i64</span><br><span class="line"> %<span class="number">53</span> = llvm.ptrtoint %<span class="number">50</span> : !llvm.ptr<f32> to i64</span><br><span class="line"> %<span class="number">54</span> = llvm.<span class="keyword">and</span> %<span class="number">53</span>, %<span class="number">52</span> : i64</span><br><span class="line"> %<span class="number">55</span> = llvm.icmp <span class="string">"eq"</span> %<span class="number">54</span>, %<span class="number">51</span> : i64</span><br><span class="line"> <span class="string">"llvm.intr.assume"</span>(%<span class="number">55</span>) : (i1) -> ()</span><br><span class="line"> %<span class="number">56</span> = llvm.bitcast %arg2 : !llvm.ptr<f32> to !llvm.ptr<i8></span><br><span class="line"> %<span class="number">57</span> = llvm.getelementptr %<span class="number">56</span>[%<span class="number">11</span>] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8></span><br><span class="line"> %<span class="number">58</span> = llvm.bitcast %<span class="number">57</span> : !llvm.ptr<i8> to !llvm.ptr<f32></span><br><span class="line"> %<span class="number">59</span> = llvm.mlir.undef : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span 
class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">60</span> = llvm.insertvalue %<span class="number">58</span>, %<span class="number">59</span>[<span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">61</span> = llvm.insertvalue %<span class="number">58</span>, %<span class="number">60</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">62</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">63</span> = llvm.insertvalue %<span class="number">62</span>, %<span class="number">61</span>[<span class="number">2</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">64</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100000</span> : index) : i64</span><br><span class="line"> %<span class="number">65</span> = llvm.insertvalue %<span class="number">64</span>, %<span class="number">63</span>[<span class="number">3</span>, <span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">66</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : index) : i64</span><br><span class="line"> %<span class="number">67</span> = 
llvm.insertvalue %<span class="number">66</span>, %<span class="number">65</span>[<span class="number">4</span>, <span class="number">0</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">68</span> = llvm.extractvalue %<span class="number">67</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">69</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : index) : i64</span><br><span class="line"> %<span class="number">70</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">63</span> : index) : i64</span><br><span class="line"> %<span class="number">71</span> = llvm.ptrtoint %<span class="number">68</span> : !llvm.ptr<f32> to i64</span><br><span class="line"> %<span class="number">72</span> = llvm.<span class="keyword">and</span> %<span class="number">71</span>, %<span class="number">70</span> : i64</span><br><span class="line"> %<span class="number">73</span> = llvm.icmp <span class="string">"eq"</span> %<span class="number">72</span>, %<span class="number">69</span> : i64</span><br><span class="line"> <span class="string">"llvm.intr.assume"</span>(%<span class="number">73</span>) : (i1) -> ()</span><br><span class="line"> %<span class="number">74</span> = nvvm.read.ptx.sreg.ctaid.x : i32</span><br><span class="line"> %<span class="number">75</span> = llvm.sext %<span class="number">74</span> : i32 to i64</span><br><span class="line"> %<span class="number">76</span> = llvm.mul %<span class="number">75</span>, %<span class="number">6</span> : i64</span><br><span class="line"> %<span class="number">77</span> = llvm.add %<span class="number">76</span>, 
%<span class="number">5</span> : i64</span><br><span class="line"> %<span class="number">78</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">77</span>, %<span class="number">10</span> : i64</span><br><span class="line"> %<span class="number">79</span> = llvm.select %<span class="number">78</span>, %<span class="number">77</span>, %<span class="number">10</span> : i1, i64</span><br><span class="line"> %<span class="number">80</span> = llvm.icmp <span class="string">"eq"</span> %<span class="number">79</span>, %<span class="number">10</span> : i64</span><br><span class="line"> llvm.cond_br %<span class="number">80</span>, ^bb1, ^bb5</span><br><span class="line"> ^bb1: <span class="comment">// pred: ^bb0</span></span><br><span class="line"> %<span class="number">81</span> = nvvm.read.ptx.sreg.tid.x : i32</span><br><span class="line"> %<span class="number">82</span> = llvm.sext %<span class="number">81</span> : i32 to i64</span><br><span class="line"> %<span class="number">83</span> = llvm.mul %<span class="number">82</span>, %<span class="number">9</span> : i64</span><br><span class="line"> %<span class="number">84</span> = llvm.mul %<span class="number">75</span>, %<span class="number">10</span> : i64</span><br><span class="line"> %<span class="number">85</span> = llvm.add %<span class="number">83</span>, %<span class="number">84</span> : i64</span><br><span class="line"> %<span class="number">86</span> = llvm.extractvalue %<span class="number">67</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">87</span> = llvm.getelementptr %<span class="number">86</span>[%<span class="number">85</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">88</span> = llvm.bitcast %<span class="number">87</span> : 
!llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">89</span> = llvm.load %<span class="number">88</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> llvm.br ^<span class="built_in">bb2</span>(%<span class="number">11</span>, %<span class="number">89</span> : i64, vector<<span class="number">4</span>xf32>)</span><br><span class="line"> ^<span class="built_in">bb2</span>(%<span class="number">90</span>: i64, %<span class="number">91</span>: vector<<span class="number">4</span>xf32>): <span class="comment">// 2 preds: ^bb1, ^bb3</span></span><br><span class="line"> %<span class="number">92</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">90</span>, %<span class="number">8</span> : i64</span><br><span class="line"> llvm.cond_br %<span class="number">92</span>, ^bb3, ^bb4</span><br><span class="line"> ^bb3: <span class="comment">// pred: ^bb2</span></span><br><span class="line"> %<span class="number">93</span> = llvm.extractvalue %<span class="number">27</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">94</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">95</span> = llvm.mul %<span class="number">85</span>, %<span class="number">94</span> : i64</span><br><span class="line"> %<span class="number">96</span> = llvm.add %<span class="number">95</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">97</span> = llvm.getelementptr %<span class="number">93</span>[%<span class="number">96</span>] : (!llvm.ptr<f32>, i64) -> 
!llvm.ptr<f32></span><br><span class="line"> %<span class="number">98</span> = llvm.bitcast %<span class="number">97</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">99</span> = llvm.load %<span class="number">98</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">100</span> = llvm.insertvalue %<span class="number">99</span>, %<span class="number">2</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">101</span> = llvm.add %<span class="number">85</span>, %<span class="number">7</span> : i64</span><br><span class="line"> %<span class="number">102</span> = llvm.extractvalue %<span class="number">27</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">103</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">104</span> = llvm.mul %<span class="number">101</span>, %<span class="number">103</span> : i64</span><br><span class="line"> %<span class="number">105</span> = llvm.add %<span class="number">104</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">106</span> = llvm.getelementptr %<span class="number">102</span>[%<span class="number">105</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">107</span> = llvm.bitcast %<span class="number">106</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span 
class="number">108</span> = llvm.load %<span class="number">107</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">109</span> = llvm.insertvalue %<span class="number">108</span>, %<span class="number">100</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">110</span> = llvm.add %<span class="number">85</span>, %<span class="number">1</span> : i64</span><br><span class="line"> %<span class="number">111</span> = llvm.extractvalue %<span class="number">27</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">112</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">113</span> = llvm.mul %<span class="number">110</span>, %<span class="number">112</span> : i64</span><br><span class="line"> %<span class="number">114</span> = llvm.add %<span class="number">113</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">115</span> = llvm.getelementptr %<span class="number">111</span>[%<span class="number">114</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">116</span> = llvm.bitcast %<span class="number">115</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">117</span> = llvm.load %<span class="number">116</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span 
class="number">118</span> = llvm.insertvalue %<span class="number">117</span>, %<span class="number">109</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">119</span> = llvm.add %<span class="number">85</span>, %<span class="number">0</span> : i64</span><br><span class="line"> %<span class="number">120</span> = llvm.extractvalue %<span class="number">27</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">121</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">122</span> = llvm.mul %<span class="number">119</span>, %<span class="number">121</span> : i64</span><br><span class="line"> %<span class="number">123</span> = llvm.add %<span class="number">122</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">124</span> = llvm.getelementptr %<span class="number">120</span>[%<span class="number">123</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">125</span> = llvm.bitcast %<span class="number">124</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">126</span> = llvm.load %<span class="number">125</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">127</span> = llvm.insertvalue %<span class="number">126</span>, %<span class="number">118</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span 
class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">128</span> = llvm.extractvalue %<span class="number">49</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">129</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">130</span> = llvm.mul %<span class="number">85</span>, %<span class="number">129</span> : i64</span><br><span class="line"> %<span class="number">131</span> = llvm.add %<span class="number">130</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">132</span> = llvm.getelementptr %<span class="number">128</span>[%<span class="number">131</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">133</span> = llvm.bitcast %<span class="number">132</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">134</span> = llvm.load %<span class="number">133</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">135</span> = llvm.insertvalue %<span class="number">134</span>, %<span class="number">2</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">136</span> = llvm.extractvalue %<span class="number">49</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span 
class="number">137</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">138</span> = llvm.mul %<span class="number">101</span>, %<span class="number">137</span> : i64</span><br><span class="line"> %<span class="number">139</span> = llvm.add %<span class="number">138</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">140</span> = llvm.getelementptr %<span class="number">136</span>[%<span class="number">139</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">141</span> = llvm.bitcast %<span class="number">140</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">142</span> = llvm.load %<span class="number">141</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">143</span> = llvm.insertvalue %<span class="number">142</span>, %<span class="number">135</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">144</span> = llvm.extractvalue %<span class="number">49</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">145</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">146</span> = llvm.mul %<span class="number">110</span>, %<span class="number">145</span> : i64</span><br><span class="line"> %<span class="number">147</span> = llvm.add %<span class="number">146</span>, 
%<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">148</span> = llvm.getelementptr %<span class="number">144</span>[%<span class="number">147</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">149</span> = llvm.bitcast %<span class="number">148</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">150</span> = llvm.load %<span class="number">149</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">151</span> = llvm.insertvalue %<span class="number">150</span>, %<span class="number">143</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">152</span> = llvm.extractvalue %<span class="number">49</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">153</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">154</span> = llvm.mul %<span class="number">119</span>, %<span class="number">153</span> : i64</span><br><span class="line"> %<span class="number">155</span> = llvm.add %<span class="number">154</span>, %<span class="number">90</span> : i64</span><br><span class="line"> %<span class="number">156</span> = llvm.getelementptr %<span class="number">152</span>[%<span class="number">155</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">157</span> = llvm.bitcast %<span class="number">156</span> : !llvm.ptr<f32> to 
!llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">158</span> = llvm.load %<span class="number">157</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">159</span> = llvm.insertvalue %<span class="number">158</span>, %<span class="number">151</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">160</span> = llvm.mlir.undef : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">161</span> = llvm.extractvalue %<span class="number">127</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">162</span> = llvm.extractvalue %<span class="number">159</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">163</span> = llvm.fadd %<span class="number">161</span>, %<span class="number">162</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">164</span> = llvm.insertvalue %<span class="number">163</span>, %<span class="number">160</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">165</span> = llvm.extractvalue %<span class="number">127</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">166</span> = llvm.extractvalue %<span class="number">159</span>[<span 
class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">167</span> = llvm.fadd %<span class="number">165</span>, %<span class="number">166</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">168</span> = llvm.insertvalue %<span class="number">167</span>, %<span class="number">164</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">169</span> = llvm.extractvalue %<span class="number">127</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">170</span> = llvm.extractvalue %<span class="number">159</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">171</span> = llvm.fadd %<span class="number">169</span>, %<span class="number">170</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">172</span> = llvm.insertvalue %<span class="number">171</span>, %<span class="number">168</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">173</span> = llvm.extractvalue %<span class="number">127</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">174</span> = llvm.extractvalue %<span class="number">159</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> 
%<span class="number">175</span> = llvm.fadd %<span class="number">173</span>, %<span class="number">174</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">176</span> = llvm.insertvalue %<span class="number">175</span>, %<span class="number">172</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">177</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">178</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">179</span> = llvm.extractelement %<span class="number">177</span>[%<span class="number">178</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">180</span> = llvm.extractvalue %<span class="number">2</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">181</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">182</span> = llvm.insertelement %<span class="number">179</span>, %<span class="number">180</span>[%<span class="number">181</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">183</span> = llvm.insertvalue %<span class="number">182</span>, %<span class="number">2</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">184</span> = llvm.extractvalue 
%<span class="number">176</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">185</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">186</span> = llvm.extractelement %<span class="number">184</span>[%<span class="number">185</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">187</span> = llvm.extractvalue %<span class="number">183</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">188</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">189</span> = llvm.insertelement %<span class="number">186</span>, %<span class="number">187</span>[%<span class="number">188</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">190</span> = llvm.insertvalue %<span class="number">189</span>, %<span class="number">183</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">191</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">192</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">193</span> = llvm.extractelement %<span class="number">191</span>[%<span class="number">192</span> : i64] : vector<<span 
class="number">4</span>xf32></span><br><span class="line"> %<span class="number">194</span> = llvm.extractvalue %<span class="number">190</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">195</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">196</span> = llvm.insertelement %<span class="number">193</span>, %<span class="number">194</span>[%<span class="number">195</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">197</span> = llvm.insertvalue %<span class="number">196</span>, %<span class="number">190</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">198</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">199</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">200</span> = llvm.extractelement %<span class="number">198</span>[%<span class="number">199</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">201</span> = llvm.extractvalue %<span class="number">197</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">202</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">203</span> = llvm.insertelement 
%<span class="number">200</span>, %<span class="number">201</span>[%<span class="number">202</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">204</span> = llvm.insertvalue %<span class="number">203</span>, %<span class="number">197</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">205</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">206</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">207</span> = llvm.extractelement %<span class="number">205</span>[%<span class="number">206</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">208</span> = llvm.extractvalue %<span class="number">204</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">209</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">210</span> = llvm.insertelement %<span class="number">207</span>, %<span class="number">208</span>[%<span class="number">209</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">211</span> = llvm.insertvalue %<span class="number">210</span>, %<span class="number">204</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">212</span> = llvm.extractvalue 
%<span class="number">176</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">213</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">214</span> = llvm.extractelement %<span class="number">212</span>[%<span class="number">213</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">215</span> = llvm.extractvalue %<span class="number">211</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">216</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">217</span> = llvm.insertelement %<span class="number">214</span>, %<span class="number">215</span>[%<span class="number">216</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">218</span> = llvm.insertvalue %<span class="number">217</span>, %<span class="number">211</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">219</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">220</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">221</span> = llvm.extractelement %<span class="number">219</span>[%<span class="number">220</span> : i64] : vector<<span 
class="number">4</span>xf32></span><br><span class="line"> %<span class="number">222</span> = llvm.extractvalue %<span class="number">218</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">223</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">224</span> = llvm.insertelement %<span class="number">221</span>, %<span class="number">222</span>[%<span class="number">223</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">225</span> = llvm.insertvalue %<span class="number">224</span>, %<span class="number">218</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">226</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">227</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">228</span> = llvm.extractelement %<span class="number">226</span>[%<span class="number">227</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">229</span> = llvm.extractvalue %<span class="number">225</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">230</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">231</span> = llvm.insertelement 
%<span class="number">228</span>, %<span class="number">229</span>[%<span class="number">230</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">232</span> = llvm.insertvalue %<span class="number">231</span>, %<span class="number">225</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">233</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">234</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">235</span> = llvm.extractelement %<span class="number">233</span>[%<span class="number">234</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">236</span> = llvm.extractvalue %<span class="number">232</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">237</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">238</span> = llvm.insertelement %<span class="number">235</span>, %<span class="number">236</span>[%<span class="number">237</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">239</span> = llvm.insertvalue %<span class="number">238</span>, %<span class="number">232</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">240</span> = llvm.extractvalue 
%<span class="number">176</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">241</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">242</span> = llvm.extractelement %<span class="number">240</span>[%<span class="number">241</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">243</span> = llvm.extractvalue %<span class="number">239</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">244</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">245</span> = llvm.insertelement %<span class="number">242</span>, %<span class="number">243</span>[%<span class="number">244</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">246</span> = llvm.insertvalue %<span class="number">245</span>, %<span class="number">239</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">247</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">248</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">249</span> = llvm.extractelement %<span class="number">247</span>[%<span class="number">248</span> : i64] : vector<<span 
class="number">4</span>xf32></span><br><span class="line"> %<span class="number">250</span> = llvm.extractvalue %<span class="number">246</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">251</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">252</span> = llvm.insertelement %<span class="number">249</span>, %<span class="number">250</span>[%<span class="number">251</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">253</span> = llvm.insertvalue %<span class="number">252</span>, %<span class="number">246</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">254</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">255</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">256</span> = llvm.extractelement %<span class="number">254</span>[%<span class="number">255</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">257</span> = llvm.extractvalue %<span class="number">253</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">258</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">259</span> = llvm.insertelement 
%<span class="number">256</span>, %<span class="number">257</span>[%<span class="number">258</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">260</span> = llvm.insertvalue %<span class="number">259</span>, %<span class="number">253</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">261</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">262</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">0</span> : i64) : i64</span><br><span class="line"> %<span class="number">263</span> = llvm.extractelement %<span class="number">261</span>[%<span class="number">262</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">264</span> = llvm.extractvalue %<span class="number">260</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">265</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">266</span> = llvm.insertelement %<span class="number">263</span>, %<span class="number">264</span>[%<span class="number">265</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">267</span> = llvm.insertvalue %<span class="number">266</span>, %<span class="number">260</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">268</span> = llvm.extractvalue 
%<span class="number">176</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">269</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">1</span> : i64) : i64</span><br><span class="line"> %<span class="number">270</span> = llvm.extractelement %<span class="number">268</span>[%<span class="number">269</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">271</span> = llvm.extractvalue %<span class="number">267</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">272</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">273</span> = llvm.insertelement %<span class="number">270</span>, %<span class="number">271</span>[%<span class="number">272</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">274</span> = llvm.insertvalue %<span class="number">273</span>, %<span class="number">267</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">275</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">276</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">2</span> : i64) : i64</span><br><span class="line"> %<span class="number">277</span> = llvm.extractelement %<span class="number">275</span>[%<span class="number">276</span> : i64] : vector<<span 
class="number">4</span>xf32></span><br><span class="line"> %<span class="number">278</span> = llvm.extractvalue %<span class="number">274</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">279</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">280</span> = llvm.insertelement %<span class="number">277</span>, %<span class="number">278</span>[%<span class="number">279</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">281</span> = llvm.insertvalue %<span class="number">280</span>, %<span class="number">274</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">282</span> = llvm.extractvalue %<span class="number">176</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">283</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">284</span> = llvm.extractelement %<span class="number">282</span>[%<span class="number">283</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">285</span> = llvm.extractvalue %<span class="number">281</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">286</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">3</span> : i64) : i64</span><br><span class="line"> %<span class="number">287</span> = llvm.insertelement 
%<span class="number">284</span>, %<span class="number">285</span>[%<span class="number">286</span> : i64] : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">288</span> = llvm.insertvalue %<span class="number">287</span>, %<span class="number">281</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">289</span> = llvm.extractvalue %<span class="number">288</span>[<span class="number">0</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">290</span> = llvm.fadd %<span class="number">289</span>, %<span class="number">91</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">291</span> = llvm.extractvalue %<span class="number">288</span>[<span class="number">1</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">292</span> = llvm.fadd %<span class="number">291</span>, %<span class="number">290</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">293</span> = llvm.extractvalue %<span class="number">288</span>[<span class="number">2</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">294</span> = llvm.fadd %<span class="number">293</span>, %<span class="number">292</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">295</span> = llvm.extractvalue %<span class="number">288</span>[<span class="number">3</span>] : !llvm.array<<span class="number">4</span> x vector<<span class="number">4</span>xf32>></span><br><span class="line"> %<span class="number">296</span> = llvm.fadd 
%<span class="number">295</span>, %<span class="number">294</span> : vector<<span class="number">4</span>xf32></span><br><span class="line"> %<span class="number">297</span> = llvm.add %<span class="number">90</span>, %<span class="number">9</span> : i64</span><br><span class="line"> llvm.br ^<span class="built_in">bb2</span>(%<span class="number">297</span>, %<span class="number">296</span> : i64, vector<<span class="number">4</span>xf32>)</span><br><span class="line"> ^bb4: <span class="comment">// pred: ^bb2</span></span><br><span class="line"> %<span class="number">298</span> = llvm.extractvalue %<span class="number">67</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">299</span> = llvm.getelementptr %<span class="number">298</span>[%<span class="number">85</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">300</span> = llvm.bitcast %<span class="number">299</span> : !llvm.ptr<f32> to !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> llvm.store %<span class="number">91</span>, %<span class="number">300</span> {alignment = <span class="number">4</span> : i64} : !llvm.ptr<vector<<span class="number">4</span>xf32>></span><br><span class="line"> llvm.br ^bb12</span><br><span class="line"> ^bb5: <span class="comment">// pred: ^bb0</span></span><br><span class="line"> %<span class="number">301</span> = nvvm.read.ptx.sreg.tid.x : i32</span><br><span class="line"> %<span class="number">302</span> = llvm.sext %<span class="number">301</span> : i32 to i64</span><br><span class="line"> %<span class="number">303</span> = llvm.icmp <span class="string">"sle"</span> %<span class="number">79</span>, %<span class="number">11</span> : i64</span><br><span class="line"> %<span class="number">304</span> 
= llvm.sub %<span class="number">11</span>, %<span class="number">79</span> : i64</span><br><span class="line"> %<span class="number">305</span> = llvm.sub %<span class="number">79</span>, %<span class="number">7</span> : i64</span><br><span class="line"> %<span class="number">306</span> = llvm.select %<span class="number">303</span>, %<span class="number">304</span>, %<span class="number">305</span> : i1, i64</span><br><span class="line"> %<span class="number">307</span> = llvm.sdiv %<span class="number">306</span>, %<span class="number">4</span> : i64</span><br><span class="line"> %<span class="number">308</span> = llvm.sub %<span class="number">11</span>, %<span class="number">307</span> : i64</span><br><span class="line"> %<span class="number">309</span> = llvm.add %<span class="number">307</span>, %<span class="number">7</span> : i64</span><br><span class="line"> %<span class="number">310</span> = llvm.select %<span class="number">303</span>, %<span class="number">308</span>, %<span class="number">309</span> : i1, i64</span><br><span class="line"> %<span class="number">311</span> = llvm.mul %<span class="number">302</span>, %<span class="number">310</span> : i64</span><br><span class="line"> %<span class="number">312</span> = llvm.mul %<span class="number">311</span>, %<span class="number">3</span> : i64</span><br><span class="line"> %<span class="number">313</span> = llvm.add %<span class="number">312</span>, %<span class="number">79</span> : i64</span><br><span class="line"> %<span class="number">314</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">313</span>, %<span class="number">310</span> : i64</span><br><span class="line"> %<span class="number">315</span> = llvm.select %<span class="number">314</span>, %<span class="number">313</span>, %<span class="number">310</span> : i1, i64</span><br><span class="line"> %<span class="number">316</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">315</span>, %<span 
class="number">11</span> : i64</span><br><span class="line"> %<span class="number">317</span> = llvm.select %<span class="number">316</span>, %<span class="number">11</span>, %<span class="number">315</span> : i1, i64</span><br><span class="line"> %<span class="number">318</span> = llvm.mul %<span class="number">75</span>, %<span class="number">10</span> : i64</span><br><span class="line"> %<span class="number">319</span> = llvm.add %<span class="number">311</span>, %<span class="number">318</span> : i64</span><br><span class="line"> llvm.br ^<span class="built_in">bb6</span>(%<span class="number">11</span> : i64)</span><br><span class="line"> ^<span class="built_in">bb6</span>(%<span class="number">320</span>: i64): <span class="comment">// 2 preds: ^bb5, ^bb11</span></span><br><span class="line"> %<span class="number">321</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">320</span>, %<span class="number">8</span> : i64</span><br><span class="line"> llvm.cond_br %<span class="number">321</span>, ^<span class="built_in">bb7</span>(%<span class="number">11</span> : i64), ^bb12</span><br><span class="line"> ^<span class="built_in">bb7</span>(%<span class="number">322</span>: i64): <span class="comment">// 2 preds: ^bb6, ^bb10</span></span><br><span class="line"> %<span class="number">323</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">322</span>, %<span class="number">317</span> : i64</span><br><span class="line"> llvm.cond_br %<span class="number">323</span>, ^<span class="built_in">bb8</span>(%<span class="number">11</span> : i64), ^bb11</span><br><span class="line"> ^<span class="built_in">bb8</span>(%<span class="number">324</span>: i64): <span class="comment">// 2 preds: ^bb7, ^bb9</span></span><br><span class="line"> %<span class="number">325</span> = llvm.icmp <span class="string">"slt"</span> %<span class="number">324</span>, %<span class="number">9</span> : i64</span><br><span class="line"> llvm.cond_br 
%<span class="number">325</span>, ^bb9, ^bb10</span><br><span class="line"> ^bb9: <span class="comment">// pred: ^bb8</span></span><br><span class="line"> %<span class="number">326</span> = llvm.add %<span class="number">319</span>, %<span class="number">322</span> : i64</span><br><span class="line"> %<span class="number">327</span> = llvm.add %<span class="number">320</span>, %<span class="number">324</span> : i64</span><br><span class="line"> %<span class="number">328</span> = llvm.extractvalue %<span class="number">27</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">329</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span class="number">330</span> = llvm.mul %<span class="number">326</span>, %<span class="number">329</span> : i64</span><br><span class="line"> %<span class="number">331</span> = llvm.add %<span class="number">330</span>, %<span class="number">327</span> : i64</span><br><span class="line"> %<span class="number">332</span> = llvm.getelementptr %<span class="number">328</span>[%<span class="number">331</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">333</span> = llvm.load %<span class="number">332</span> : !llvm.ptr<f32></span><br><span class="line"> %<span class="number">334</span> = llvm.extractvalue %<span class="number">49</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">2</span> x i64>, array<<span class="number">2</span> x i64>)></span><br><span class="line"> %<span class="number">335</span> = llvm.mlir.<span class="built_in">constant</span>(<span class="number">100</span> : index) : i64</span><br><span class="line"> %<span 
class="number">336</span> = llvm.mul %<span class="number">326</span>, %<span class="number">335</span> : i64</span><br><span class="line"> %<span class="number">337</span> = llvm.add %<span class="number">336</span>, %<span class="number">327</span> : i64</span><br><span class="line"> %<span class="number">338</span> = llvm.getelementptr %<span class="number">334</span>[%<span class="number">337</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">339</span> = llvm.load %<span class="number">338</span> : !llvm.ptr<f32></span><br><span class="line"> %<span class="number">340</span> = llvm.extractvalue %<span class="number">67</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">341</span> = llvm.getelementptr %<span class="number">340</span>[%<span class="number">326</span>] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32></span><br><span class="line"> %<span class="number">342</span> = llvm.load %<span class="number">341</span> : !llvm.ptr<f32></span><br><span class="line"> %<span class="number">343</span> = llvm.fadd %<span class="number">333</span>, %<span class="number">339</span> : f32</span><br><span class="line"> %<span class="number">344</span> = llvm.fadd %<span class="number">343</span>, %<span class="number">342</span> : f32</span><br><span class="line"> %<span class="number">345</span> = llvm.extractvalue %<span class="number">67</span>[<span class="number">1</span>] : !llvm.<span class="keyword">struct</span><(ptr<f32>, ptr<f32>, i64, array<<span class="number">1</span> x i64>, array<<span class="number">1</span> x i64>)></span><br><span class="line"> %<span class="number">346</span> = llvm.getelementptr %<span class="number">345</span>[%<span class="number">326</span>] : (!llvm.ptr<f32>, i64) -> 
!llvm.ptr<f32></span><br><span class="line"> llvm.store %<span class="number">344</span>, %<span class="number">346</span> : !llvm.ptr<f32></span><br><span class="line"> %<span class="number">347</span> = llvm.add %<span class="number">324</span>, %<span class="number">7</span> : i64</span><br><span class="line"> llvm.br ^<span class="built_in">bb8</span>(%<span class="number">347</span> : i64)</span><br><span class="line"> ^bb10: <span class="comment">// pred: ^bb8</span></span><br><span class="line"> %<span class="number">348</span> = llvm.add %<span class="number">322</span>, %<span class="number">7</span> : i64</span><br><span class="line"> llvm.br ^<span class="built_in">bb7</span>(%<span class="number">348</span> : i64)</span><br><span class="line"> ^bb11: <span class="comment">// pred: ^bb7</span></span><br><span class="line"> %<span class="number">349</span> = llvm.add %<span class="number">320</span>, %<span class="number">9</span> : i64</span><br><span class="line"> llvm.br ^<span class="built_in">bb6</span>(%<span class="number">349</span> : i64)</span><br><span class="line"> ^bb12: <span class="comment">// 2 preds: ^bb4, ^bb6</span></span><br><span class="line"> llvm.<span class="keyword">return</span></span><br><span 
class="line">}</span><br></pre></td></tr></table></figure></li></ul></li></ul></li></ul></li><li><p>createConvertToHALPass</p></li><li><p>createFixupLegacySyncPass</p></li><li><p>addCleanupPatterns</p></li><li><p>createLinkExecutablesPass</p></li><li><p>createResolveExportOrdinalsPass</p></li><li><p>createMaterializeResourceCachesPass</p></li><li><p>createInlineDeviceSwitchesPass</p></li><li><p>createMemoizeDeviceQueriesPass</p></li><li><p>addCleanupPatterns</p></li><li><p>createElideRedundantCommandsPass</p></li><li><p>mlir::createLowerAffinePass</p></li><li><p>mlir::createConvertSCFToCFPass</p></li><li><p>IREE::Util::createCombineInitializersPass</p></li><li><p>addCleanupPatterns</p></li><li><p>createSerializeExecutablesPass</p></li><li><p>mlir::createSymbolDCEPass</p></li></ul>]]></content>
<summary type="html"><p>The main role of HAL::HALTransformPassPipeline is to perform tiling, vectorization, and bufferization, distribute the computation workload, and ultimately generate code for the target device. For example, the dispatch source code for the cuda target is lowered to NVVM IR.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
<title>IREE Compilation Pipeline Explained (Part 5)</title>
<link href="https://hjchen2.github.io/2023/02/13/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B5/"/>
<id>https://hjchen2.github.io/2023/02/13/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B5/</id>
<published>2023-02-13T13:57:20.000Z</published>
<updated>2023-02-17T11:32:02.532Z</updated>
<content type="html"><![CDATA[<p>The main role of IREE::Stream::StreamTransformPassPipeline is to convert the program to the stream dialect, optimize the encoding of variables, partition the scheduling subgraphs, generate an asynchronous scheduling strategy, and implement the memory planning strategy.</p><span id="more"></span><ul><li><p>buildStreamTensorPassPipeline</p><ul><li><p>IREE::Stream::createVerifyInputPass</p><p>Verifies the legality of the program.</p></li><li><p>IREE::Stream::createOutlineConstantsPass</p><p>Outlines dense constants inside the module into global dense constants.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">1</span> = flow.tensor.reshape %<span class="number">0</span> : tensor<<span class="number">1</span>x10xf32> -> tensor<<span class="number">10</span>xf32></span><br><span 
class="line"> %<span class="number">2</span> = flow.tensor.empty : tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">1</span>, %cst, %<span class="number">2</span>) : (tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>) -> %<span class="number">2</span></span><br><span class="line"> %<span class="number">4</span> = flow.tensor.reshape %<span class="number">3</span> : tensor<<span class="number">10</span>xf32> -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span class="number">1</span>x10xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> @_constant {noinline} = dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span 
class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %_constant = util.global.load @_constant : tensor<<span class="number">10</span>xf32></span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">1</span> = flow.tensor.reshape %<span class="number">0</span> : tensor<<span class="number">1</span>x10xf32> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">2</span> = flow.tensor.empty : tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">1</span>, %_constant, %<span class="number">2</span>) : (tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>) -> %<span class="number">2</span></span><br><span class="line"> %<span class="number">4</span> = flow.tensor.reshape %<span class="number">3</span> : tensor<<span class="number">10</span>xf32> -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span class="number">1</span>x10xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Stream::createConvertToStreamPass</p><p>Converts the <code>IREE::Util</code>, <code>IREE::Flow</code>, <code>IREE::HAL</code>, and <code>std</code> dialects to the <code>IREE::Stream</code> dialect.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">module</span> {</span><br><span class="line"> util.global <span class="keyword">private</span> @_constant {noinline} = dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span 
class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line"> flow.executable <span class="keyword">private</span> @test_dispatch_0 {</span><br><span class="line"> flow.executable.<span class="keyword">export</span> <span class="keyword">public</span> @test_dispatch_0_generic_10 <span class="built_in">workgroups</span>(%arg0: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_10</span>(%arg0: !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>>, %arg2: !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xf32>>) {</span><br><span class="line"> %<span class="number">0</span> = flow.dispatch.tensor.load %arg0, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">1</span> = flow.dispatch.tensor.load %arg1, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">2</span> = flow.dispatch.tensor.load %arg2, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span 
class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xf32>> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">10</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">4</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">4</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">3</span>, %arg2, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">10</span>xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xf32>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %_constant = util.global.load @_constant : tensor<<span 
class="number">10</span>xf32></span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">1</span> = flow.tensor.reshape %<span class="number">0</span> : tensor<<span class="number">1</span>x10xf32> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">2</span> = flow.tensor.empty : tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">1</span>, %_constant, %<span class="number">2</span>) : (tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>) -> %<span class="number">2</span></span><br><span class="line"> %<span class="number">4</span> = flow.tensor.reshape %<span class="number">3</span> : tensor<<span class="number">10</span>xf32> -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span class="number">1</span>x10xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换为</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span 
class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">module</span> {</span><br><span class="line"> util.global <span class="keyword">private</span> @_constant : !stream.resource<constant></span><br><span class="line"> util.global <span class="keyword">private</span> @_constant__size : index</span><br><span class="line"> util.initializer {</span><br><span class="line"> %cst = stream.tensor.constant : tensor<<span 
class="number">10</span>xf32> in !stream.resource<constant> = dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">0</span> = stream.resource.size %cst : !stream.resource<constant></span><br><span class="line"> util.global.store %cst, @_constant : !stream.resource<constant></span><br><span class="line"> util.global.store %<span class="number">0</span>, @_constant__size : index</span><br><span class="line"> util.initializer.<span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> stream.executable <span class="keyword">private</span> @test_dispatch_0 {</span><br><span class="line"> stream.executable.<span class="keyword">export</span> <span class="keyword">public</span> @test_dispatch_0_generic_10 <span class="built_in">workgroups</span>(%arg0: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0</span><br><span class="line"> stream.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding) {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = stream.binding.subspan %arg0[%c0] : !stream.binding -> 
!flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>></span><br><span class="line"> %<span class="number">1</span> = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>></span><br><span class="line"> %<span class="number">2</span> = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xf32>></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">4</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xf32>> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xf32>> -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span>, %<span class="number">4</span> : 
tensor<<span class="number">10</span>xf32>, tensor<<span class="number">10</span>xf32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">10</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">7</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">7</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">6</span>, %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">10</span>xf32> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xf32>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %_constant = util.global.load @_constant : !stream.resource<constant></span><br><span class="line"> %_constant__size = util.global.load @_constant__size : index</span><br><span class="line"> %<span class="number">0</span> = stream.async.transfer %_constant : !stream.resource<constant>{%_constant__size} -> !stream.resource<*>{%_constant__size}</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : 
index</span><br><span class="line"> %c10_0 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c10_0]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">1</span> = stream.tensor.<span class="keyword">sizeof</span> tensor<<span class="number">1</span>x10xf32> : index</span><br><span class="line"> %<span class="number">2</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%<span class="number">1</span>}</span><br><span class="line"> %<span class="number">3</span> = stream.async.transfer %<span class="number">2</span> : !stream.resource<external>{%<span class="number">1</span>} -> !stream.resource<*>{%<span class="number">1</span>}</span><br><span class="line"> %<span class="number">4</span> = stream.tensor.<span class="keyword">sizeof</span> tensor<<span class="number">10</span>xf32> : index</span><br><span class="line"> %<span class="number">5</span> = stream.tensor.clone %<span class="number">3</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<*>{%<span class="number">1</span>} -> tensor<<span class="number">10</span>xf32> in !stream.resource<*>{%<span class="number">4</span>}</span><br><span class="line"> %<span class="number">6</span> = stream.tensor.<span class="keyword">sizeof</span> tensor<<span class="number">10</span>xf32> : index</span><br><span class="line"> %empty = stream.tensor.empty : tensor<<span class="number">10</span>xf32> in !stream.resource<*>{%<span class="number">6</span>}</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span 
class="number">7</span> = stream.async.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">5</span>[%c0 to %<span class="number">4</span> <span class="keyword">for</span> %<span class="number">4</span>], %<span class="number">0</span>[%c0 to %_constant__size <span class="keyword">for</span> %_constant__size], %empty[%c0 to %<span class="number">6</span> <span class="keyword">for</span> %<span class="number">6</span>]) : (!stream.resource<*>{%<span class="number">4</span>}, !stream.resource<*>{%_constant__size}, !stream.resource<*>{%<span class="number">6</span>}) -> %empty{%<span class="number">6</span>}</span><br><span class="line"> %<span class="number">8</span> = stream.tensor.<span class="keyword">sizeof</span> tensor<<span class="number">1</span>x10xf32> : index</span><br><span class="line"> %<span class="number">9</span> = stream.tensor.clone %<span class="number">7</span> : tensor<<span class="number">10</span>xf32> in !stream.resource<*>{%<span class="number">6</span>} -> tensor<<span class="number">1</span>x10xf32> in !stream.resource<*>{%<span class="number">8</span>}</span><br><span class="line"> %<span class="number">10</span> = stream.async.transfer %<span class="number">9</span> : !stream.resource<*>{%<span class="number">8</span>} -> !stream.resource<external>{%<span class="number">8</span>}</span><br><span class="line"> %<span class="number">11</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">10</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%<span class="number">8</span>} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">11</span> : !hal.buffer_view</span><br><span class="line"> }</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure><p>Notice that, apart from <code>flow.executable</code>, all <code>tensor</code> types in the module have been converted to <code>stream.resource</code> and <code>index</code>, while the <code>hal.buffer_view</code> type is still preserved. The <code>util.global</code> constant whose initial value was a tensor has been converted into a <code>stream.resource</code> and an <code>index</code> without initial values, and a <code>util.initializer</code> has been generated to initialize the <code>stream.resource</code> and <code>index</code>. <code>util.global.load</code> is converted to <code>util.global.load</code> + <code>stream.async.transfer</code>, <code>hal.tensor.import</code> to <code>stream.tensor.import</code> + <code>stream.async.transfer</code>, <code>hal.tensor.export</code> to <code>stream.async.transfer</code> + <code>stream.tensor.export</code>, <code>flow.tensor.reshape</code> to <code>stream.tensor.clone</code>, and <code>flow.executable</code> to <code>stream.executable</code>, with the inner <code>flow.executable.export</code> converted to <code>stream.executable.export</code> and the arguments of the inner func op converted from <code>flow.dispatch.tensor</code> to <code>stream.binding</code>.</p></li><li><p>IREE::Stream::createVerifyLoweringToTensorsPass</p><p>Verifies the legality of the program.</p></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Util::createCombineInitializersPass</p><p>Merges all <code>util.initializer</code> ops.</p></li></ul></li><li><p>buildStreamAsyncPassPipeline</p><ul><li><p>IREE::Stream::createEncodeHostTensorsPass</p><p>Its main role is to widen tensor element bit widths to the next power of two and align them to bytes: i1~i7 become i8 (1 byte), i9~i15 become i16 (2 bytes), i17~i31 become i32 (4 bytes), and i33~i63 become i64 (8 bytes).</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">util.initializer {</span><br><span class="line"> %cst = stream.tensor.constant : tensor<<span class="number">10</span>xi4> in !stream.resource<constant> = dense<[<span 
class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>, <span class="number">7</span>, <span class="number">-8</span>, <span class="number">-7</span>]> : tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">0</span> = stream.resource.size %cst : !stream.resource<constant></span><br><span class="line"> util.global.store %cst, @_constant : !stream.resource<constant></span><br><span class="line"> util.global.store %<span class="number">0</span>, @_constant__size : index</span><br><span class="line"> util.initializer.<span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换为</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">util.initializer {</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %cst = stream.async.constant : !stream.resource<constant>{%c10} = dense<[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>, <span class="number">7</span>, <span class="number">8</span>, <span class="number">9</span>]> : tensor<<span class="number">10</span>xi8></span><br><span class="line"> util.global.store %cst, @_constant : !stream.resource<constant></span><br><span class="line"> util.global.store %c10, @_constant__size : index</span><br><span class="line"> util.initializer.<span 
class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>The type of <code>%cst</code> has been converted from i4 to i8; in addition, <code>stream.tensor.constant</code> has been converted to <code>stream.async.constant</code>, and <code>%0 = stream.resource.size %cst : !stream.resource<constant></code> has been replaced directly with the constant <code>%c10</code>.</p></li><li><p>IREE::Stream::createEncodeDeviceTensorsPass</p><p>Serves the same purpose as createEncodeHostTensorsPass, except that createEncodeDeviceTensorsPass operates on the ops inside a <code>stream.executable</code>.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding) {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi4>></span><br><span class="line"> %<span class="number">1</span> = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi4>></span><br><span class="line"> %<span class="number">2</span> = stream.binding.subspan %arg2[%c0] : !stream.binding -> 
!flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xi4>></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi4>> -> tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">4</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi4>> -> tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xi4>> -> tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span>, %<span class="number">4</span> : tensor<<span class="number">10</span>xi4>, tensor<<span class="number">10</span>xi4>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">10</span>xi4>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: i4, %in_0: i4, %out: i4):</span><br><span class="line"> %<span class="number">7</span> = arith.addi %in, %in_0 : 
i4</span><br><span class="line"> linalg.yield %<span class="number">7</span> : i4</span><br><span class="line"> } -> tensor<<span class="number">10</span>xi4></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">6</span>, %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">10</span>xi4> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xi4>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换为,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line">builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding) {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span 
class="number">10</span>xi8>></span><br><span class="line"> %<span class="number">1</span> = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi8>></span><br><span class="line"> %<span class="number">2</span> = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xi8>></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi8>> -> tensor<<span class="number">10</span>xi8></span><br><span class="line"> %<span class="number">4</span> = arith.trunci %<span class="number">3</span> : tensor<<span class="number">10</span>xi8> to tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>xi8>> -> tensor<<span class="number">10</span>xi8></span><br><span class="line"> %<span class="number">6</span> = arith.trunci %<span class="number">5</span> : tensor<<span class="number">10</span>xi8> to tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xi8>> -> tensor<<span class="number">10</span>xi8></span><br><span class="line"> %<span class="number">8</span> = arith.trunci %<span 
class="number">7</span> : tensor<<span class="number">10</span>xi8> to tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">4</span>, %<span class="number">6</span> : tensor<<span class="number">10</span>xi4>, tensor<<span class="number">10</span>xi4>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<<span class="number">10</span>xi4>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: i4, %in_0: i4, %out: i4):</span><br><span class="line"> %<span class="number">11</span> = arith.addi %in, %in_0 : i4</span><br><span class="line"> linalg.yield %<span class="number">11</span> : i4</span><br><span class="line"> } -> tensor<<span class="number">10</span>xi4></span><br><span class="line"> %<span class="number">10</span> = arith.extui %<span class="number">9</span> : tensor<<span class="number">10</span>xi4> to tensor<<span class="number">10</span>xi8></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">10</span>, %<span class="number">2</span>, offsets = [<span class="number">0</span>], sizes = [<span class="number">10</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">10</span>xi8> -> !flow.dispatch.tensor<readwrite:tensor<<span class="number">10</span>xi8>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure><p>Note that the result type of <code>stream.binding.subspan</code> is converted from i4 to i8, and that an <code>arith.trunci</code> is inserted after each <code>flow.dispatch.tensor.load</code> to truncate the i8 values back to i4 before they take part in the computation inside <code>linalg.generic</code>.</p></li><li><p>IREE::Stream::createMaterializeBuiltinsPass</p></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Stream::createMaterializeCopyOnWritePass</p><p>Inserts a copy on write, so that in-place updates can be supported more efficiently while correct execution semantics are preserved.</p></li><li><p>IREE::Stream::createElideAsyncCopiesPass</p><p>Eliminates the unnecessary copies inserted by MaterializeCopyOnWritePass.</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>IREE::Stream::createEmplaceAllocationsPass</p><p>Tries to eliminate the <code>stream.async.update</code> op that follows a <code>stream.async.dispatch</code>. When the result of the <code>stream.async.dispatch</code> is not bound to a value, the target of the <code>stream.async.update</code> can be tied to the result of the <code>stream.async.dispatch</code>, so that the dispatch writes its result into the target directly.</p></li><li><p>IREE::Stream::createRefineUsagePass</p><p>Determines the lifetime of each <code>stream.resource</code> and infers its type. The <code>stream.resource</code> types are:</p><ul><li>Unknown: <code>stream.resource<*></code></li><li>External: <code>stream.resource<external></code>, memory managed by an external program</li><li>Staging: <code>stream.resource<staging></code>, a staging buffer used for uploads/downloads</li><li>Transient: <code>stream.resource<transient></code>, a transient value that lives across streams</li><li>Variable: <code>stream.resource<variable></code>, a persistent value that lives across streams</li><li>Constant: <code>stream.resource<constant></code>, an immediate value (a constant) that persists for the whole program.</li></ul><p>In addition, redundant <code>stream.async.transfer</code> ops are eliminated.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span 
class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c10]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">10</span>xf32> in !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">1</span> = stream.async.transfer %<span class="number">0</span> : !stream.resource<external>{%c40} -> !stream.resource<*>{%c40}</span><br><span class="line"> %<span class="number">2</span> = stream.async.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">1</span>[%c0 to %c40 <span class="keyword">for</span> %c40]) : (!stream.resource<*>{%c40}) -> !stream.resource<*>{%c40}</span><br><span class="line"> %<span class="number">3</span> = stream.async.transfer %<span class="number">2</span> : !stream.resource<*>{%c40} -> !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">4</span> = stream.tensor.<span class="keyword">export</span> %<span 
class="number">3</span> : tensor<<span class="number">10</span>xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">4</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted to:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c10]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">10</span>xf32> in !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">1</span> = 
stream.async.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">0</span>[%c0 to %c40 <span class="keyword">for</span> %c40]) : (!stream.resource<external>{%c40}) -> !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">2</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">1</span> : tensor<<span class="number">10</span>xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>Note that <code>!stream.resource<*>{ %c40}</code> has been refined to <code>!stream.resource<external>{ %c40}</code>, and that the two <code>stream.async.transfer</code> ops have been removed.</p></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Stream::createScheduleExecutionPass</p><p>Uses a heuristic to partition each callable (including <code>util.initializer</code>) into multiple parts for scheduling. Each part forms its own <code>stream.async.execute</code>, and every <code>stream.async.execute</code> is followed by a <code>stream.timepoint.await</code> op that synchronizes on the execution result.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : 
index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %_constant = util.global.load @_constant : !stream.resource<constant></span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c10]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">1</span> = stream.async.alloca : !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">2</span> = stream.async.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%<span class="number">0</span>[%c0 to %c40 <span class="keyword">for</span> %c40], %_constant[%c0 to %c40 <span class="keyword">for</span> %c40], %<span class="number">1</span>[%c0 to %c40 <span class="keyword">for</span> %c40]) : (!stream.resource<external>{%c40}, !stream.resource<constant>{%c40}, !stream.resource<external>{%c40}) -> %<span class="number">1</span>{%c40}</span><br><span class="line"> %<span class="number">3</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">2</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">3</span> : 
!hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted to:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %_constant = util.global.load @_constant : !stream.resource<constant></span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c10]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span 
class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40}</span><br><span class="line"> %results, %result_timepoint = stream.async.execute <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c40}, %_constant as %arg2: !stream.resource<constant>{%c40}) -> !stream.resource<external>{%c40} {</span><br><span class="line"> %<span class="number">3</span> = stream.async.alloca : !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">4</span> = stream.async.dispatch @test_dispatch_0::@test_dispatch_0_generic_10[%c10](%arg1[%c0 to %c40 <span class="keyword">for</span> %c40], %arg2[%c0 to %c40 <span class="keyword">for</span> %c40], %<span class="number">3</span>[%c0 to %c40 <span class="keyword">for</span> %c40]) : (!stream.resource<external>{%c40}, !stream.resource<constant>{%c40}, !stream.resource<external>{%c40}) -> %<span class="number">3</span>{%c40}</span><br><span class="line"> stream.yield %<span class="number">4</span> : !stream.resource<external>{%c40}</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> %<span class="number">1</span> = stream.timepoint.await %result_timepoint => %results : !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">2</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">1</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure><p>Note: this example contains only a single part.</p></li><li><p>IREE::Stream::createScheduleConcurrencyPass</p><p>Further partitions each <code>stream.async.execute</code> into multiple concurrently schedulable regions, each of which forms a <code>stream.async.concurrent</code>.</p></li><li><p>IREE::Stream::createPropagateTimepointsPass</p><p>Associates each <code>stream.resource</code> with a <code>stream.timepoint</code>, replaces the original <code>stream.resource</code> in the code with a <code>stream.resource + stream.timepoint</code> pair, and inserts awaits where needed.</p><ul><li><p><code>util.global</code></p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> @_constant : !stream.resource<constant></span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @_constant__timepoint = <span class="meta">#stream.timepoint<span class="string"><immediate></span> : !stream.timepoint</span></span><br><span class="line">util.global <span class="keyword">private</span> @_constant : !stream.resource<constant></span><br></pre></td></tr></table></figure></li><li><p><code>util.global.load</code></p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">%_constant = util.global.load @_constant : !stream.resource<constant></span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%_constant__timepoint = util.global.load @_constant__timepoint : !stream.timepoint</span><br><span 
class="line">%_constant = util.global.load @_constant : !stream.resource<constant></span><br><span class="line">%<span class="number">0</span> = stream.timepoint.await %_constant__timepoint => %_constant : !stream.resource<constant>{%c40}</span><br></pre></td></tr></table></figure></li><li><p><code>util.global.store</code></p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">util.global.store %<span class="number">0</span>, @_constant : !stream.resource<constant></span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">util.global.store %result_timepoint, @_constant__timepoint : !stream.timepoint</span><br><span class="line">util.global.store %results, @_constant : !stream.resource<constant></span><br></pre></td></tr></table></figure></li><li><p><code>func.func</code></p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%<span class="number">0</span>: !stream.resource) {</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%t: !stream.timepoint, %<span class="number">0</span>: !stream.resource) {</span><br><span class="line"> %<span class="number">1</span> = stream.timepoint.await %t, %<span class="number">0</span></span><br><span class="line"> 
...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p><code>call</code></p><p>Since an await has already been inserted inside the callee func, the redundant await before the call can be removed, and an await on the func's return value is inserted after the call instead.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">1</span> = stream.timepoint.await %t, %<span class="number">0</span></span><br><span class="line">%r = call @<span class="built_in">foo</span>(%<span class="number">1</span>)</span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%rt, %r = call @<span class="built_in">foo</span>(%t, %<span class="number">0</span>)</span><br><span class="line">stream.timepoint.await %rt, %t</span><br></pre></td></tr></table></figure></li><li><p><code>return</code></p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">1</span> = stream.timepoint.await %t, %<span class="number">0</span></span><br><span class="line"><span class="keyword">return</span> %<span class="number">1</span></span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">return</span> %t, %<span class="number">0</span></span><br></pre></td></tr></table></figure></li><li><p><code>branch</code></p><p>Moves the await on the block arguments into the branch target.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td 
class="code"><pre><span class="line">%<span class="number">1</span> = stream.timepoint.await %t, %<span class="number">0</span></span><br><span class="line">br ^<span class="built_in">bb1</span>(%<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line">^<span class="built_in">bb1</span>(%b):</span><br><span class="line"> ...</span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">br ^<span class="built_in">bb1</span>(%t, %<span class="number">0</span>)</span><br><span class="line">^<span class="built_in">bb1</span>(%a, %b):</span><br><span class="line"> %<span class="number">1</span> = stream.timepoint.await %a, %b</span><br></pre></td></tr></table></figure></li><li><p><code>stream.async.execute</code></p><p>Binds a <code>stream.timepoint</code> to every input argument that does not yet have one, computes the maximum timepoint over the arguments before the <code>stream.async.execute</code>, and has the <code>stream.async.execute</code> await that maximum timepoint.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%results, %result_timepoint = stream.async.execute <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c40}, %_constant as %arg2: !stream.resource<constant>{%c40}) -> !stream.resource<external>{%c40} {</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">3</span> = 
stream.timepoint.join <span class="built_in">max</span>(%<span class="number">2</span>, %_constant__timepoint) => !stream.timepoint</span><br><span class="line">%results, %result_timepoint = stream.async.execute <span class="built_in">await</span>(%<span class="number">3</span>) => <span class="built_in">with</span>(%<span class="number">1</span> as %arg1: !stream.resource<external>{%c40}, %_constant as %arg2: !stream.resource<constant>{%c40}) -> !stream.resource<external>{%c40} {</span><br><span class="line"> ...</span><br><span class="line"> }</span><br></pre></td></tr></table></figure></li></ul></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Stream::createVerifyLoweringToAsyncPass</p><p>Verifies that the program is valid after the lowering-to-async stage.</p></li></ul></li><li><p>buildStreamCmdPassPipeline</p><ul><li><p>IREE::Stream::createScheduleAllocationPass</p><ul><li><p>First, all constant ops are aggregated into a single <code>stream.resource.constants</code> and hoisted out of the region; the results of the <code>stream.resource.constants</code> are appended to the region's input arguments (except for constants that were yielded directly).</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">%results, %result_timepoint = stream.async.execute <span class="built_in">with</span>() -> !stream.resource<constant>{%c40} {</span><br><span class="line"> %cst = stream.async.constant : !stream.resource<constant>{%c40} = dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line"> stream.yield %cst : 
!stream.resource<constant>{%c40}</span><br><span class="line">} => !stream.timepoint</span><br></pre></td></tr></table></figure><p>becomes:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">%results, %result_timepoint = stream.resource.constants :</span><br><span class="line"> !stream.resource<constant>{%c40} = dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line"> => !stream.timepoint</span><br><span class="line">%<span class="number">0</span> = stream.cmd.execute <span class="built_in">with</span>() {</span><br><span class="line">} => !stream.timepoint</span><br><span class="line">%<span class="number">1</span> = stream.timepoint.join <span class="built_in">max</span>(%result_timepoint, %<span class="number">0</span>) => !stream.timepoint</span><br></pre></td></tr></table></figure></li><li><p>Analyzes the types of the resources in each <code>stream.async.execute</code> region and the alias relationships between them, then allocates storage for them grouped by resource type. Results that are not tied to an input (i.e. not updated in place) are allocated together outside the region as external storage by <code>stream.resource.alloc</code>, and the region consumes the alloc's results in a tied fashion. For intermediate temporary resources, <code>stream.resource.pack</code> first computes the required allocation size, then a single <code>stream.resource.alloca</code> allocates the transient storage, and a <code>stream.resource.dealloca</code> is inserted after the region to release it.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c8 = arith.constant <span class="number">8</span> : index</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c2]) <span 
class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32> in !stream.resource<external>{%c8}</span><br><span class="line"> <span class="comment">// stream.async.execute</span></span><br><span class="line"> %results, %result_timepoint = stream.async.execute <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c8}) -> !stream.resource<external>{%c40} {</span><br><span class="line"> %<span class="number">3</span> = stream.async.dispatch @predict_dispatch_0::@predict_dispatch_0_matmul_1x10x2[%c1, %c10](%arg1[%c0 to %c8 <span class="keyword">for</span> %c8]) : (!stream.resource<external>{%c8}) -> !stream.resource<transient>{%c40}</span><br><span class="line"> %<span class="number">4</span> = stream.async.dispatch @predict_dispatch_1::@predict_dispatch_1_generic_10[%c1](%<span class="number">3</span>[%c0 to %c40 <span class="keyword">for</span> %c40]) : (!stream.resource<transient>{%c40}) -> !stream.resource<transient>{%c4}</span><br><span class="line"> %<span class="number">5</span> = stream.async.dispatch @predict_dispatch_2::@predict_dispatch_2_generic_1x10[%c1, %c10](%<span class="number">3</span>[%c0 to %c40 <span class="keyword">for</span> %c40], %<span class="number">4</span>[%c0 to %c4 <span class="keyword">for</span> %c4]) : (!stream.resource<transient>{%c40}, !stream.resource<transient>{%c4}) -> !stream.resource<transient>{%c40}</span><br><span class="line"> %<span class="number">6</span> = stream.async.dispatch @predict_dispatch_3::@predict_dispatch_3_generic_10[%c1](%<span class="number">5</span>[%c0 to %c40 <span class="keyword">for</span> %c40]) : (!stream.resource<transient>{%c40}) -> !stream.resource<transient>{%c4}</span><br><span class="line"> %<span class="number">7</span> = 
stream.async.dispatch @predict_dispatch_4::@predict_dispatch_4_generic_1x10[%c1, %c10](%<span class="number">5</span>[%c0 to %c40 <span class="keyword">for</span> %c40], %<span class="number">6</span>[%c0 to %c4 <span class="keyword">for</span> %c4]) : (!stream.resource<transient>{%c40}, !stream.resource<transient>{%c4}) -> !stream.resource<external>{%c40}</span><br><span class="line"> stream.yield %<span class="number">7</span> : !stream.resource<external>{%c40}</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> %<span class="number">1</span> = stream.timepoint.await %result_timepoint => %results : !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">2</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">1</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>which is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span 
class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c8 = arith.constant <span class="number">8</span> : index</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> 
hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c2]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32> in !stream.resource<external>{%c8}</span><br><span class="line"> %c0_0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> <span class="comment">// allocate storage for the output resource</span></span><br><span class="line"> %<span class="number">1</span> = stream.resource.alloc uninitialized : !stream.resource<external>{%c40}</span><br><span class="line"> <span class="comment">// compute the storage size required by the transient resources</span></span><br><span class="line"> %<span class="number">2</span>:<span class="number">5</span> = stream.resource.pack <span class="built_in">slices</span>({</span><br><span class="line"> [<span class="number">0</span>, <span class="number">2</span>] = %c40, <span class="comment">// [0, 2] is a resource's lifetime interval, %c40 is its size</span></span><br><span class="line"> [<span class="number">1</span>, <span class="number">2</span>] = %c4,</span><br><span class="line"> [<span class="number">2</span>, <span class="number">4</span>] = %c40,</span><br><span class="line"> [<span class="number">3</span>, <span class="number">4</span>] = %c4</span><br><span class="line"> }) : index</span><br><span class="line"> <span class="comment">// allocate storage for the transient resources</span></span><br><span class="line"> %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>} => !stream.timepoint</span><br><span class="line"> %<span class="number">3</span> = stream.cmd.execute <span class="built_in">await</span>(%result_timepoint) => <span 
class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c8}, %<span class="number">1</span> as %arg2: !stream.resource<external>{%c40}, %result as %arg3: !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}) {</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_0::@predict_dispatch_0_matmul_1x10x2[%c1, %c10] {</span><br><span class="line"> ro %arg1[%c0 <span class="keyword">for</span> %c8] : !stream.resource<external>{%c8},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">1</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_1::@predict_dispatch_1_generic_10[%c1] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">1</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">2</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_2::@predict_dispatch_2_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">1</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">2</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> 
wo %arg3[%<span class="number">2</span>#<span class="number">3</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_3::@predict_dispatch_3_generic_10[%c1] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">3</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">4</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_4::@predict_dispatch_4_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">3</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">4</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg2[%c0_0 <span class="keyword">for</span> %c40] : !stream.resource<external>{%c40}</span><br><span class="line"> }</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> <span class="comment">// release the allocated temporary storage</span></span><br><span class="line"> %<span class="number">4</span> = stream.resource.dealloca <span class="built_in">await</span>(%<span class="number">3</span>) => %result : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>} => !stream.timepoint</span><br><span 
class="line"> %<span class="number">5</span> = stream.timepoint.join <span class="built_in">max</span>(%<span class="number">4</span>, %<span class="number">3</span>) => !stream.timepoint</span><br><span class="line"> %<span class="number">6</span> = stream.timepoint.await %<span class="number">5</span> => %<span class="number">1</span> : !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">7</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">6</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">7</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul></li><li><p>IREE::Stream::createPackConstantsPass</p><p>Splits the results of <code>stream.resource.constants</code> into two groups according to lifetime type, Constant and Variable, and replaces each group with a single <code>util.buffer.constant</code>.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">util.initializer {</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %results, %result_timepoint = stream.resource.constants :</span><br><span class="line"> !stream.resource<constant>{%c40} = dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span 
class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32></span><br><span class="line"> => !stream.timepoint</span><br><span class="line"> %<span class="number">0</span> = stream.cmd.execute <span class="built_in">with</span>() {</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> %<span class="number">1</span> = stream.timepoint.join <span class="built_in">max</span>(%result_timepoint, %<span class="number">0</span>) => !stream.timepoint</span><br><span class="line"> util.global.store %results, @_constant : !stream.resource<constant></span><br><span class="line"> util.global.store %<span class="number">1</span>, @_constant__timepoint : !stream.timepoint</span><br><span class="line"> util.initializer.<span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span 
class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line">util.initializer {</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %buffer_cst = util.buffer.constant {alignment = <span class="number">64</span> : index} : !util.buffer = <span class="meta">#util.composite<64xi8, [</span></span><br><span class="line"> dense<[<span class="number">0.000000e+00</span>, <span class="number">0.00999999977</span>, <span class="number">2.000000e-02</span>, <span class="number">3.000000e-02</span>, <span class="number">4.000000e-02</span>, <span class="number">5.000000e-02</span>, <span class="number">6.000000e-02</span>, <span class="number">7.000000e-02</span>, <span class="number">8.000000e-02</span>, <span class="number">9.000000e-02</span>]> : tensor<<span class="number">10</span>xf32>,</span><br><span class="line"> dense<<span class="number">0</span>> : vector<<span class="number">24</span>xi8>, <span class="comment">// padding (unused data)</span></span><br><span class="line">]></span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c64 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> <span class="comment">// try to map the buffer as the target (!stream.resource<constant>)</span></span><br><span class="line"> %did_map, %result = stream.resource.try_map %buffer_cst[%c0] : !util.buffer -> i1, !stream.resource<constant>{%c64}</span><br><span class="line"> %<span class="number">0</span>:<span class="number">2</span> = scf.<span class="keyword">if</span> %did_map -> (!stream.resource<constant>, !stream.timepoint) {</span><br><span class="line"> <span class="comment">// if mapping succeeds, return the mapped result (!stream.resource<constant>) directly</span></span><br><span class="line"> %<span class="number">4</span> = 
stream.timepoint.immediate => !stream.timepoint</span><br><span class="line"> scf.yield %result, %<span class="number">4</span> : !stream.resource<constant>, !stream.timepoint</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> <span class="comment">// if the buffer cannot be mapped, first map it as a staging buffer (stage), then allocate new storage and copy the data out of the staging buffer (copy).</span></span><br><span class="line"> <span class="comment">// if the lifetime type is Variable, try_map is not needed and this branch (stage + copy) is taken directly.</span></span><br><span class="line"> %<span class="number">4</span> = stream.resource.map %buffer_cst[%c0] : !util.buffer -> !stream.resource<staging>{%c64}</span><br><span class="line"> %<span class="number">5</span> = stream.resource.alloc uninitialized : !stream.resource<constant>{%c64}</span><br><span class="line"> %<span class="number">6</span> = stream.cmd.execute <span class="built_in">with</span>(%<span class="number">4</span> as %arg0: !stream.resource<staging>{%c64}, %<span class="number">5</span> as %arg1: !stream.resource<constant>{%c64}) {</span><br><span class="line"> stream.cmd.copy %arg0[%c0], %arg1[%c0], %c64 : !stream.resource<staging>{%c64} -> !stream.resource<constant>{%c64}</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> scf.yield %<span class="number">5</span>, %<span class="number">6</span> : !stream.resource<constant>, !stream.timepoint</span><br><span class="line"> }</span><br><span class="line"> %<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>#<span class="number">0</span>[%c0] : !stream.resource<constant>{%c64} -> !stream.resource<constant>{%c40}</span><br><span class="line"> %<span class="number">2</span> = stream.cmd.execute <span class="built_in">with</span>() {</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> %<span class="number">3</span> = stream.timepoint.join <span class="built_in">max</span>(%<span class="number">0</span>#<span class="number">1</span>, 
%<span class="number">2</span>) => !stream.timepoint</span><br><span class="line"> util.global.store %<span class="number">1</span>, @_constant : !stream.resource<constant></span><br><span class="line"> util.global.store %<span class="number">3</span>, @_constant__timepoint : !stream.timepoint</span><br><span class="line"> util.initializer.<span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Stream::createPackAllocationsPass</p><p>Converts a <code>stream.resource.alloc</code> that allocates multiple resources into <code>stream.resource.pack + stream.resource.alloc</code>, retrieving each individual resource through <code>stream.resource.subview</code>.</p></li><li><p>IREE::Stream::createLayoutSlicesPass</p><p>Lowers <code>stream.resource.pack</code> into the concrete computation performed by the memory-reuse (slice-packing) algorithm.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span 
class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c8 = arith.constant <span class="number">8</span> : index</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c2]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = 
stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32> in !stream.resource<external>{%c8}</span><br><span class="line"> %c0_0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> <span class="comment">// allocate storage for the output resource</span></span><br><span class="line"> %<span class="number">1</span> = stream.resource.alloc uninitialized : !stream.resource<external>{%c40}</span><br><span class="line"> <span class="comment">// compute the storage size required by the transient resources</span></span><br><span class="line"> %<span class="number">2</span>:<span class="number">5</span> = stream.resource.pack <span class="built_in">slices</span>({</span><br><span class="line"> [<span class="number">0</span>, <span class="number">2</span>] = %c40, <span class="comment">// [0, 2] is a resource's lifetime interval, %c40 is its size</span></span><br><span class="line"> [<span class="number">1</span>, <span class="number">2</span>] = %c4,</span><br><span class="line"> [<span class="number">2</span>, <span class="number">4</span>] = %c40,</span><br><span class="line"> [<span class="number">3</span>, <span class="number">4</span>] = %c4</span><br><span class="line"> }) : index</span><br><span class="line"> <span class="comment">// allocate storage for the transient resources</span></span><br><span class="line"> %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>} => !stream.timepoint</span><br><span class="line"> %<span class="number">3</span> = stream.cmd.execute <span class="built_in">await</span>(%result_timepoint) => <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c8}, %<span class="number">1</span> as %arg2: !stream.resource<external>{%c40}, %result as %arg3: !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}) {</span><br><span class="line"> stream.cmd.dispatch 
@predict_dispatch_0::@predict_dispatch_0_matmul_1x10x2[%c1, %c10] {</span><br><span class="line"> ro %arg1[%c0 <span class="keyword">for</span> %c8] : !stream.resource<external>{%c8},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">1</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_1::@predict_dispatch_1_generic_10[%c1] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">1</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">2</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_2::@predict_dispatch_2_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">1</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">2</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">3</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_3::@predict_dispatch_3_generic_10[%c1] 
{</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">3</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg3[%<span class="number">2</span>#<span class="number">4</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_4::@predict_dispatch_4_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">3</span> <span class="keyword">for</span> %c40] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> ro %arg3[%<span class="number">2</span>#<span class="number">4</span> <span class="keyword">for</span> %c4] : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>},</span><br><span class="line"> wo %arg2[%c0_0 <span class="keyword">for</span> %c40] : !stream.resource<external>{%c40}</span><br><span class="line"> }</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> <span class="comment">// release the allocated temporary storage</span></span><br><span class="line"> %<span class="number">4</span> = stream.resource.dealloca <span class="built_in">await</span>(%<span class="number">3</span>) => %result : !stream.resource<transient>{%<span class="number">2</span>#<span class="number">0</span>} => !stream.timepoint</span><br><span class="line"> %<span class="number">5</span> = stream.timepoint.join <span class="built_in">max</span>(%<span class="number">4</span>, %<span class="number">3</span>) => !stream.timepoint</span><br><span class="line"> %<span class="number">6</span> = stream.timepoint.await %<span class="number">5</span> => %<span class="number">1</span> : 
!stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">7</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">6</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">7</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>which is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span 
class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c8 = arith.constant <span class="number">8</span> : index</span><br><span class="line"> %c40 = arith.constant <span class="number">40</span> : index</span><br><span class="line"> %c4 = arith.constant <span class="number">4</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c10 = arith.constant <span class="number">10</span> : index</span><br><span class="line"> %c553648160_i32 = arith.constant <span class="number">553648160</span> : i32</span><br><span class="line"> %c1_i32 = arith.constant <span class="number">1</span> : i32</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> hal.buffer_view.assert<%arg0 : !hal.buffer_view> <span class="built_in">message</span>(<span class="string">"tensor"</span>) <span class="built_in">shape</span>([%c1, %c2]) <span class="built_in">type</span>(%c553648160_i32) <span class="built_in">encoding</span>(%c1_i32)</span><br><span class="line"> %<span class="number">0</span> = stream.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32> in !stream.resource<external>{%c8}</span><br><span class="line"> %c0_0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">1</span> = stream.resource.alloc uninitialized : !stream.resource<external>{%c40}</span><br><span class="line"> 
%c0_1 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c64 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> %c64_2 = arith.constant <span class="number">64</span> : index</span><br><span class="line"> %c128 = arith.constant <span class="number">128</span> : index</span><br><span class="line"> %c128_3 = arith.constant <span class="number">128</span> : index</span><br><span class="line"> %c192 = arith.constant <span class="number">192</span> : index</span><br><span class="line"> %c192_4 = arith.constant <span class="number">192</span> : index</span><br><span class="line"> %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%c192_4} => !stream.timepoint</span><br><span class="line"> %<span class="number">2</span> = stream.cmd.execute <span class="built_in">await</span>(%result_timepoint) => <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c8}, %<span class="number">1</span> as %arg2: !stream.resource<external>{%c40}, %result as %arg3: !stream.resource<transient>{%c192_4}) {</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_0::@predict_dispatch_0_matmul_1x10x2[%c1, %c10] {</span><br><span class="line"> ro %arg1[%c0 <span class="keyword">for</span> %c8] : !stream.resource<external>{%c8},</span><br><span class="line"> wo %arg3[%c0_1 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192_4}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_1::@predict_dispatch_1_generic_10[%c1] {</span><br><span class="line"> ro %arg3[%c0_1 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192_4},</span><br><span class="line"> wo %arg3[%c64_2 <span class="keyword">for</span> %c4] : !stream.resource<transient>{%c192_4}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch 
@predict_dispatch_2::@predict_dispatch_2_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%c0_1 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192_4},</span><br><span class="line"> ro %arg3[%c64_2 <span class="keyword">for</span> %c4] : !stream.resource<transient>{%c192_4},</span><br><span class="line"> wo %arg3[%c128_3 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192_4}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_3::@predict_dispatch_3_generic_10[%c1] {</span><br><span class="line"> ro %arg3[%c128_3 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192_4},</span><br><span class="line"> wo %arg3[%c0_1 <span class="keyword">for</span> %c4] : !stream.resource<transient>{%c192_4}</span><br><span class="line"> }</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_4::@predict_dispatch_4_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%c128_3 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192_4},</span><br><span class="line"> ro %arg3[%c0_1 <span class="keyword">for</span> %c4] : !stream.resource<transient>{%c192_4},</span><br><span class="line"> wo %arg2[%c0_0 <span class="keyword">for</span> %c40] : !stream.resource<external>{%c40}</span><br><span class="line"> }</span><br><span class="line"> } => !stream.timepoint</span><br><span class="line"> %<span class="number">3</span> = stream.resource.dealloca <span class="built_in">await</span>(%<span class="number">2</span>) => %result : !stream.resource<transient>{%c192_4} => !stream.timepoint</span><br><span class="line"> %<span class="number">4</span> = stream.timepoint.join <span class="built_in">max</span>(%<span class="number">3</span>, %<span class="number">2</span>) => !stream.timepoint</span><br><span class="line"> %<span class="number">5</span> = stream.timepoint.await %<span class="number">4</span> => %<span 
class="number">1</span> : !stream.resource<external>{%c40}</span><br><span class="line"> %<span class="number">6</span> = stream.tensor.<span class="keyword">export</span> %<span class="number">5</span> : tensor<<span class="number">1</span>x10xf32> in !stream.resource<external>{%c40} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">6</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Util::createPropagateSubrangesPass</p><p>把resource转换成 (resource, size, offset, length)的元组。</p><ul><li><p>util.global</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> @_constant : !stream.resource<constant></span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> @_constant : !stream.resource<constant></span><br><span class="line">util.global <span class="keyword">private</span> @_constant_size : index</span><br><span class="line">util.global <span class="keyword">private</span> @_constant_offset : index</span><br><span class="line">util.global <span class="keyword">private</span> @_constant_length : index</span><br></pre></td></tr></table></figure></li><li><p>util.global.load</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">0</span> = util.global.load @foo : !stream.resource</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span 
class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">0</span> = util.global.load @foo : !stream.resource</span><br><span class="line">%s = util.global.load @foo_size : index</span><br><span class="line">%o = util.global.load @foo_offset : index</span><br><span class="line">%l = util.global.load @foo_length : index</span><br><span class="line">%<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>[%o] :</span><br><span class="line"> !stream.resource<*>{%s} -> !stream.resource<*>{%l}</span><br></pre></td></tr></table></figure></li><li><p>util.global.store</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>[%o] :</span><br><span class="line"> !stream.resource<*>{%s} -> !stream.resource<*>{%l}</span><br><span class="line">util.global.store %<span class="number">1</span>, @foo : !stream.resource</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">util.global.store %<span class="number">0</span>, @foo : !stream.resource <span class="comment">// 这里语义是正确的吗???</span></span><br><span class="line">util.global.store %s, @foo_size : index</span><br><span class="line">util.global.store %o, @foo_offset : index</span><br><span class="line">util.global.store %l, @foo_length : index</span><br></pre></td></tr></table></figure></li><li><p>func.func</p><figure 
class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%<span class="number">0</span>: !stream.resource) {</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%<span class="number">0</span>: !stream.resource, %sz: index, %o: index, %l: index) {</span><br><span class="line"> %<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>[%o] : {%sz} -> {%l}</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>call</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>[%o] : {%sz} -> {%l}</span><br><span class="line">%r = call @<span class="built_in">foo</span>(%<span class="number">1</span>)</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%r, %rsz, %ro, %rl = call @<span class="built_in">foo</span>(%<span class="number">0</span>, %sz, %o, %l)</span><br><span class="line">%<span class="number">2</span> = stream.resource.subview %r[%ro] : {%rsz} -> {%rl}</span><br></pre></td></tr></table></figure></li><li><p>return</p><figure class="highlight 
<table><tr><td class="gutter">">
c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>[%o] : {%sz} -> {%l}</span><br><span class="line"><span class="keyword">return</span> %<span class="number">1</span></span><br></pre></td></tr></table></figure><p>is converted to:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">return</span> %<span class="number">0</span>, %sz, %o, %l</span><br></pre></td></tr></table></figure></li><li><p>branch</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">1</span> = stream.resource.subview %<span class="number">0</span>[%o] : {%sz} -> {%l}</span><br><span class="line">br ^<span class="built_in">bb1</span>(%<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line">^<span class="built_in">bb1</span>(%b):</span><br><span class="line"> ...</span><br></pre></td></tr></table></figure><p>is converted to:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">br ^<span class="built_in">bb1</span>(%<span class="number">0</span>, %sz, %o, %l)</span><br><span class="line"> </span><br><span class="line">^<span class="built_in">bb1</span>(%a, %b, %c, %d):</span><br><span class="line"> %<span class="number">1</span> = stream.resource.subview %a[%b] : {%c} -> 
{%d}</span><br></pre></td></tr></table></figure></li><li><p>cond_branch</p></li></ul></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Stream::createVerifyLoweringToCmdPass</p><p>Verifies the legality of the program.</p></li></ul></li><li><p>buildStreamOptimizationPassPipeline</p><ul><li><p>addCleanupPatterns</p></li><li><p>mlir::createConvertSCFToCFPass</p><p>Converts structured control flow ops into lower-level control flow ops expressed with basic blocks.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%pred: i32, %arg1: tensor<<span class="number">2</span>x10xf32>, %arg2: tensor<<span class="number">2</span>x10xf32>) -> tensor<<span class="number">2</span>x10xf32> {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : i32</span><br><span class="line"> %<span class="number">0</span> = arith.cmpi sgt, %pred, %c0 : i32</span><br><span class="line"> %<span class="number">1</span> = scf.<span class="keyword">if</span> %<span class="number">0</span> -> (tensor<<span class="number">2</span>x10xf32>) {</span><br><span class="line"> %<span class="number">2</span> = mhlo.add %arg1, %arg2 : tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> scf.yield %<span class="number">2</span> : tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> %<span class="number">2</span> = mhlo.subtract %arg1, %arg2 : tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> scf.yield %<span class="number">2</span> : 
tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">1</span> : tensor<<span class="number">2</span>x10xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%pred: i32, %arg1: tensor<<span class="number">2</span>x10xf32>, %arg2: tensor<<span class="number">2</span>x10xf32>) -> tensor<<span class="number">2</span>x10xf32> {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : i32</span><br><span class="line"> %<span class="number">0</span> = arith.cmpi sgt, %pred, %c0 : i32</span><br><span class="line"> cf.cond_br %<span class="number">0</span>, ^bb1, ^bb2</span><br><span class="line"> ^bb1:</span><br><span class="line"> %<span class="number">2</span> = mhlo.add %arg1, %arg2 : tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> cf.br ^<span class="built_in">bb3</span>(%<span class="number">2</span> : tensor<<span class="number">2</span>x10xf32>)</span><br><span class="line"> ^bb2:</span><br><span class="line"> %<span class="number">3</span> = mhlo.subtract %arg1, %arg2 : tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> cf.br ^<span class="built_in">bb3</span>(%<span class="number">3</span> : tensor<<span class="number">2</span>x10xf32>)</span><br><span class="line"> ^<span 
class="built_in">bb3</span>(%<span class="number">4</span>: tensor<<span class="number">2</span>x10xf32>):</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">4</span> : tensor<<span class="number">2</span>x10xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>addCleanupPatterns</p></li><li><p>IREE::Stream::createElideTimepointsPass</p><p>消除已经确信到达的等待。比如</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%timepoint0 = ...</span><br><span class="line">%timepoint1 = ... <span class="built_in">await</span>(%timepoint0)</span><br><span class="line">%timepoint2 = stream.timepoint.join <span class="built_in">max</span>(%timepoint0, %timepoint1)</span><br></pre></td></tr></table></figure><p>timepoint1到达时timepoint0一定已经达到过,因此可以转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%timepoint0 = ...</span><br><span class="line">%timepoint1 = ... <span class="built_in">await</span>(%timepoint0)</span><br><span class="line">%timepoint2 = stream.timepoint.join <span class="built_in">max</span>(%timepoint1)</span><br></pre></td></tr></table></figure><p>canonicalization之后最终是</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%timepoint0 = ...</span><br><span class="line">%timepoint1 = ... 
<span class="built_in">await</span>(%timepoint0)</span><br><span class="line">%timepoint2 = %timepoint1</span><br></pre></td></tr></table></figure></li><li><p>IREE::Util::createFixedPointIteratorPass</p><p>该pass触发重复执行一个passpipeline,直到达到固定迭代次数或最大迭代次数。这里的pipeline包括前面的addCleanupPatterns和createElideTimepointsPass两个子pass。</p></li><li><p>IREE::Stream::createFuseDispatchBindingsPass</p><p>根据<code>stream.cmd.dispatch</code> 的resource关系合并dispatchexecutable的bindings,比如<code>stream.cmd.dispatch</code>两个resource是同一个地址的不同range,则可以计算每个resource在base地址上的偏移,并将这两个resource合并成一个binding,在dispatchexecutable中根据偏移来截取每个被合并的binding。该操作默认只合并readonly的resource。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line">stream.executable <span 
class="keyword">private</span> @predict_dispatch_2 {</span><br><span class="line"> stream.executable.<span class="keyword">export</span> <span class="keyword">public</span> @<span class="function">predict_dispatch_2_generic_1x10 <span class="title">workgroups</span><span class="params">(%arg0: index, %arg1: index)</span> -> <span class="params">(index, index, index)</span> </span>{</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1</span><br><span class="line"> stream.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">predict_dispatch_2_generic_1x10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding) {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span class="number">1</span>x10xf32>></span><br><span class="line"> %<span class="number">1</span> = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<f32>></span><br><span class="line"> %<span class="number">2</span> = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<<span class="number">1</span>x10xf32>></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch.tensor.load %<span class="number">0</span>, offsets = [<span class="number">0</span>, <span class="number">0</span>], sizes = [<span class="number">1</span>, <span class="number">10</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">1</span>x10xf32>> -> tensor<<span 
class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">4</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32></span><br><span class="line"> %<span class="number">5</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> ()>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span>, %<span class="number">4</span> : tensor<<span class="number">1</span>x10xf32>, tensor<f32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">1</span>x10xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">7</span> = arith.subf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">8</span> = math.exp %<span class="number">7</span> : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">6</span>, %<span class="number">2</span>, offsets = [<span class="number">0</span>, <span class="number">0</span>], sizes = [<span class="number">1</span>, <span class="number">10</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">1</span>x10xf32> -> !flow.dispatch.tensor<writeonly:tensor<<span 
class="number">1</span>x10xf32>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">2</span> = stream.cmd.execute <span class="built_in">await</span>(%result_timepoint) => <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c8}, %<span class="number">1</span> as %arg2: !stream.resource<external>{%c40}, %result as %arg3: !stream.resource<transient>{%c192}) {</span><br><span class="line"> ...</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_2::@predict_dispatch_2_generic_1x10[%c1, %c10] {</span><br><span class="line"> ro %arg3[%c0 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192},</span><br><span class="line"> ro %arg3[%c64 <span class="keyword">for</span> %c4] : !stream.resource<transient>{%c192},</span><br><span class="line"> wo %arg3[%c128 <span class="keyword">for</span> %c40] : !stream.resource<transient>{%c192}</span><br><span class="line"> }</span><br><span class="line"> ...</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span 
class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line">stream.executable <span class="keyword">private</span> @predict_dispatch_2 {</span><br><span class="line"> stream.executable.<span class="keyword">export</span> <span class="keyword">public</span> @<span class="function">predict_dispatch_2_generic_1x10 <span class="title">workgroups</span><span class="params">(%arg0: index, %arg1: index)</span> -> <span class="params">(index, index, index)</span> </span>{</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1</span><br><span class="line"> stream.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">predict_dispatch_2_generic_1x10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: index) {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = arith.addi %c0, %arg2 : index</span><br><span 
class="line"> %<span class="number">1</span> = stream.binding.subspan %arg0[%<span class="number">0</span>] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<<span class="number">1</span>x10xf32>></span><br><span class="line"> %<span class="number">2</span> = arith.addi %c0, %arg3 : index</span><br><span class="line"> %<span class="number">3</span> = stream.binding.subspan %arg0[%<span class="number">2</span>] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<f32>></span><br><span class="line"> %<span class="number">4</span> = arith.addi %c0, %arg4 : index</span><br><span class="line"> %<span class="number">5</span> = stream.binding.subspan %arg1[%<span class="number">4</span>] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<<span class="number">1</span>x10xf32>></span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %<span class="number">1</span>, offsets = [<span class="number">0</span>, <span class="number">0</span>], sizes = [<span class="number">1</span>, <span class="number">10</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">1</span>x10xf32>> -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %<span class="number">3</span>, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32></span><br><span class="line"> %<span class="number">8</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> ()>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span 
class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">6</span>, %<span class="number">7</span> : tensor<<span class="number">1</span>x10xf32>, tensor<f32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<<span class="number">1</span>x10xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">10</span> = arith.subf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">11</span> = math.exp %<span class="number">10</span> : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x10xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">9</span>, %<span class="number">5</span>, offsets = [<span class="number">0</span>, <span class="number">0</span>], sizes = [<span class="number">1</span>, <span class="number">10</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">1</span>x10xf32> -> !flow.dispatch.tensor<writeonly:tensor<<span class="number">1</span>x10xf32>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> ...</span><br><span class="line"> %<span class="number">2</span> = stream.cmd.execute <span class="built_in">await</span>(%result_timepoint) => <span class="built_in">with</span>(%<span class="number">0</span> as %arg1: !stream.resource<external>{%c8}, %<span class="number">1</span> as %arg2: !stream.resource<external>{%c40}, %result as %arg3: !stream.resource<transient>{%c192}) 
{</span><br><span class="line"> ...</span><br><span class="line"> stream.cmd.dispatch @predict_dispatch_2::@predict_dispatch_2_generic_1x10[%c1, %c10](%c0, %c64, %c128 : index, index, index) {</span><br><span class="line"> ro %arg3[%c0_0 <span class="keyword">for</span> %c192] : !stream.resource<transient>{%c192},</span><br><span class="line"> wo %arg3[%c0_0 <span class="keyword">for</span> %c192] : !stream.resource<transient>{%c192}</span><br><span class="line"> }</span><br><span class="line"> ...</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>Note that the resources of <code>stream.cmd.dispatch @predict_dispatch_2</code> have been merged into 2, and the bindings among the parameters of the <code>predict_dispatch_2_generic_1x10</code> dispatch executable are likewise reduced to 2, while 3 index operands carrying offsets are added; each merged binding is sliced out of the packed resource at its offset.</p></li><li><p>IREE::Stream::createPackDispatchOperandsPass</p><p>Converts scalar/index-typed operands of dispatch executables into i32 or i64.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict_dispatch_2_generic_1x10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: index) {</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict_dispatch_2_generic_1x10</span>(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: i32, %arg3: i32, %arg4: i32) {</span><br><span class="line"> ...</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createCSEPass</p></li><li><p>IREE::Stream::createFoldUniformOperandsPass</p><p>Folds operands that are identical across all dispatch sites of a dispatch executable.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">stream.cmd.dispatch @<span class="built_in">foo</span>(%c1, %c100 : index, index)</span><br><span class="line">stream.cmd.dispatch @<span class="built_in">foo</span>(%c1, %c101 : index, index)</span><br><span class="line">stream.cmd.dispatch @<span class="built_in">foo2</span>(%c1, %c101 : index, index)</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">stream.cmd.dispatch @<span class="built_in">foo</span>(%c100 : index)</span><br><span class="line">stream.cmd.dispatch @<span class="built_in">foo</span>(%c101 : index)</span><br><span class="line">stream.cmd.dispatch @<span class="built_in">foo2</span>()</span><br></pre></td></tr></table></figure><p><code>@foo</code> inlines <code>%c1</code>, and <code>@foo2</code> inlines both <code>%c1</code> and <code>%c101</code>.</p></li><li><p>IREE::Stream::createAnnotateDispatchArgumentsPass</p><p>Annotates the arguments of dispatch executables with potential-value and alignment information.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict_dispatch_2_generic_1x10</span>(%arg0: !stream.binding, %arg1: !stream.binding) {</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict_dispatch_2_generic_1x10</span>(%arg0: !stream.binding {stream.alignment = <span class="number">64</span> : index}, %arg1: !stream.binding {stream.alignment = <span class="number">64</span> : index}) {</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul></li><li><p>IREE::Stream::createMemoizeChannelsPass</p><p>Finds all <code>stream.channel.default</code> ops, creates a global for each of them, creates the corresponding channel at initialization time and stores the result into that global, and finally replaces each <code>stream.channel.default</code> op with a <code>util.global.load</code> of the global.</p></li><li><p>addCleanupPatterns</p></li><li><p>mlir::createSymbolDCEPass</p></li></ul>]]></content>
<summary type="html"><p>IREE::Stream::StreamTransformPassPipeline lowers the program into the stream dialect, optimizes how variables are encoded, partitions the work into schedulable subgraphs, generates an asynchronous scheduling policy, and implements the memory-planning strategy.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
<title>IREE Compilation Flow Explained (Part 4)</title>
<link href="https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B4/"/>
<id>https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B4/</id>
<published>2023-01-04T13:15:20.000Z</published>
<updated>2023-02-17T11:57:53.436Z</updated>
<content type="html"><![CDATA[<p>The main job of IREE::Flow::buildFlowTransformPassPipeline is to run a series of peephole optimizations, such as converting a 1x1 conv2d into a matmul, tiling, and op fusion, ultimately splitting the workload into <code>flow.executable</code>s. The relevant passes and their effects are described below.</p><span id="more"></span><ul><li><p>IREE::Util::createDemoteF64ToF32Pass</p><p>Demotes F64 types to F32.</p></li><li><p>IREE::Flow::createConvertConv2D1x1ToMatmulPass</p><p>Converts a 1x1 <code>linalg.conv_2d_nhwc_hwcf</code> into <code>linalg.matmul</code>.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// func.func @conv(%input : tensor<1x2x2x3xf32>, %filter: tensor<1x1x3x4xf32>) -> tensor<1x2x2x4xf32> {</span></span><br><span class="line"><span class="comment">// %0 = mhlo.convolution(%input, %filter)</span></span><br><span class="line"><span class="comment">// dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],</span></span><br><span class="line"><span class="comment">// window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]}</span></span><br><span class="line"><span class="comment">// {batch_group_count = 1 : i64, feature_group_count = 1 : i64}</span></span><br><span class="line"><span class="comment">// : (tensor<1x2x2x3xf32>, tensor<1x1x3x4xf32>) -> tensor<1x2x2x4xf32></span></span><br><span class="line"><span class="comment">// return %0 : tensor<1x2x2x4xf32></span></span><br><span class="line"><span class="comment">// 
}</span></span><br><span class="line">func.func @<span class="built_in">conv</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x2x2x3xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">1</span>x1x3x4xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x2x2x4xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">1</span>x2x2x4xf32>) -> tensor<<span class="number">1</span>x2x2x4xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.conv_2d_nhwc_hwcf {dilations = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>, strides = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">1</span>x2x2x3xf32>, tensor<<span class="number">1</span>x1x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">1</span>x2x2x4xf32>) -> tensor<<span class="number">1</span>x2x2x4xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span 
class="number">1</span>x2x2x4xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">conv</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x2x2x3xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">1</span>x1x3x4xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x2x2x4xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">1</span>x2x2x4xf32>) -> tensor<<span class="number">1</span>x2x2x4xf32></span><br><span 
class="line"> %<span class="number">4</span> = tensor.collapse_shape %<span class="number">0</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">1</span>x2x2x3xf32> into tensor<<span class="number">4</span>x3xf32></span><br><span class="line"> %<span class="number">5</span> = tensor.collapse_shape %<span class="number">1</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">1</span>x1x3x4xf32> into tensor<<span class="number">3</span>x4xf32></span><br><span class="line"> %<span class="number">6</span> = tensor.collapse_shape %<span class="number">3</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">1</span>x2x2x4xf32> into tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">7</span> = linalg.matmul <span class="built_in">ins</span>(%<span class="number">4</span>, %<span class="number">5</span> : tensor<<span class="number">4</span>x3xf32>, tensor<<span class="number">3</span>x4xf32>) <span class="built_in">outs</span>(%<span class="number">6</span> : tensor<<span class="number">4</span>x4xf32>) -> tensor<<span class="number">4</span>x4xf32></span><br><span class="line"> %<span class="number">8</span> = tensor.expand_shape %<span class="number">7</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">4</span>x4xf32> into tensor<<span class="number">1</span>x2x2x4xf32></span><br><span class="line"> %<span class="number">9</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">8</span> : tensor<<span class="number">1</span>x2x2x4xf32> -> 
!hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">9</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Flow::createConvertConv2DToImg2ColPass</p><p>Converts conv2d into img2col (im2col) form. Disabled by default.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// %0 = mhlo.convolution(%input, %filter)</span></span><br><span class="line"><span class="comment">// dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],</span></span><br><span class="line"><span class="comment">// window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]}</span></span><br><span class="line"><span class="comment">// {batch_group_count = 1 : i64, feature_group_count = 1 : i64}</span></span><br><span class="line"><span class="comment">// : (tensor<1x4x4x3xf32>, tensor<2x2x3x4xf32>) -> tensor<1x3x3x4xf32></span></span><br><span class="line">func.func @<span class="built_in">conv</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x4x4x3xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span 
class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>x2x3x4xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.conv_2d_nhwc_hwcf {dilations = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>, strides = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">1</span>x4x4x3xf32>, tensor<<span class="number">2</span>x2x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span class="number">1</span>x3x3x4xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">conv</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x4x4x3xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>x2x3x4xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">3</span>] : tensor<<span class="number">1</span>x3x3x2x2x3xf32></span><br><span class="line"> %<span class="number">5</span> = linalg.generic 
{indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span> : tensor<<span class="number">1</span>x4x4x3xf32>) <span class="built_in">outs</span>(%<span class="number">4</span> : tensor<<span class="number">1</span>x3x3x2x2x3xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32):</span><br><span class="line"> linalg.yield %arg2 : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x3x3x2x2x3xf32></span><br><span class="line"> %<span class="number">6</span> = tensor.collapse_shape %<span class="number">5</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>]] : tensor<<span class="number">1</span>x3x3x2x2x3xf32> into tensor<<span class="number">9</span>x12xf32></span><br><span class="line"> %<span class="number">7</span> = tensor.collapse_shape %<span class="number">1</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">2</span>x2x3x4xf32> into tensor<<span class="number">12</span>x4xf32></span><br><span class="line"> %<span class="number">8</span> = tensor.collapse_shape %<span class="number">3</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">1</span>x3x3x4xf32> into 
tensor<<span class="number">9</span>x4xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.matmul <span class="built_in">ins</span>(%<span class="number">6</span>, %<span class="number">7</span> : tensor<<span class="number">9</span>x12xf32>, tensor<<span class="number">12</span>x4xf32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<<span class="number">9</span>x4xf32>) -> tensor<<span class="number">9</span>x4xf32></span><br><span class="line"> %<span class="number">10</span> = tensor.expand_shape %<span class="number">9</span> [[<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>], [<span class="number">3</span>]] : tensor<<span class="number">9</span>x4xf32> into tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">11</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">10</span> : tensor<<span class="number">1</span>x3x3x4xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">11</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Flow::createDetachElementwiseFromNamedOpsPass</p><p>Rewrites <code>buffer = linalg.generic_op + linalg.named_payload_op</code> into <code>tmp_buffer = linalg.named_payload_op; buffer = linalg.generic_op + tmp_buffer</code>; the main purpose is to detach the upstream <code>generic op</code> from the <code>named_payload_op</code> so that the <code>named_payload_op</code> writes its result into a fresh buffer.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span 
class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x4x4x3xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>x2x3x4xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">import</span> %arg2 : !hal.buffer_view -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> </span><br><span class="line"> %<span class="number">3</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">5</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span 
class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">2</span> : tensor<<span class="number">1</span>x3x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">4</span> : tensor<<span class="number">1</span>x3x3x4xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg3: f32, %arg4: f32):</span><br><span class="line"> %<span class="number">8</span> = arith.addf %arg3, %arg3 : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> </span><br><span class="line"> %<span class="number">6</span> = linalg.conv_2d_nhwc_hwcf {dilations = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>, strides = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">1</span>x4x4x3xf32>, tensor<<span class="number">2</span>x2x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">7</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">6</span> : tensor<<span class="number">1</span>x3x3x4xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">7</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x4x4x3xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>x2x3x4xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">import</span> %arg2 : !hal.buffer_view -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> </span><br><span class="line"> %<span class="number">3</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span 
class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">5</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">2</span> : tensor<<span class="number">1</span>x3x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">4</span> : tensor<<span class="number">1</span>x3x3x4xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg3: f32, %arg4: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %arg3, %arg3 : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> </span><br><span class="line"> %<span class="number">6</span> = linalg.init_tensor [<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">4</span>] : tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">7</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">6</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">8</span> = linalg.conv_2d_nhwc_hwcf {dilations = dense<<span class="number">1</span>> : tensor<<span class="number">2</span>xi64>, strides = dense<<span class="number">1</span>> : 
tensor<<span class="number">2</span>xi64>} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">1</span>x4x4x3xf32>, tensor<<span class="number">2</span>x2x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">7</span> : tensor<<span class="number">1</span>x3x3x4xf32>) -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">8</span>, %<span class="number">5</span> : tensor<<span class="number">1</span>x3x3x4xf32>, tensor<<span class="number">1</span>x3x3x4xf32>) <span class="built_in">outs</span>(%<span class="number">7</span> : tensor<<span class="number">1</span>x3x3x4xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg3: f32, %arg4: f32, %arg5: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %arg3, %arg4 : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x3x3x4xf32></span><br><span class="line"> %<span class="number">10</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">9</span> : tensor<<span class="number">1</span>x3x3x4xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">10</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Flow::createVerifyInputLegalityPass</p><p>Verifies that the input program is legal.</p></li><li><p>IREE::Flow::createConvertLinalgMatmulToMmt4DPass</p><p>Tiles a 2-D <code>linalg.matmul</code> into <code>linalg.mmt4d</code>. Disabled by default; it can be enabled with the <code>--iree-flow-mmt4d-target-options="enable_generic_slow arch=cuda"</code> option.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">256</span>x256xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">128</span>, <span class="number">256</span>] : tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">128</span>x256xf32>) -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.matmul <span
class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">128</span>x256xf32>, tensor<<span class="number">256</span>x256xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">128</span>x256xf32>) -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span class="number">128</span>x256xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span
class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">256</span>x256xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">128</span>, <span class="number">256</span>] : tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">128</span>x256xf32>) -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">4</span> = tensor.expand_shape %<span class="number">0</span> [[<span class="number">0</span>, <span class="number">1</span>], [<span class="number">2</span>, <span class="number">3</span>]] : tensor<<span class="number">128</span>x256xf32> into tensor<<span class="number">16</span>x8x128x2xf32></span><br><span class="line"> %<span class="number">5</span> = tensor.expand_shape %<span class="number">1</span> [[<span class="number">0</span>, <span class="number">1</span>], [<span class="number">2</span>, <span class="number">3</span>]] : tensor<<span class="number">256</span>x256xf32> into tensor<<span class="number">128</span>x2x64x4xf32></span><br><span class="line"> %<span class="number">6</span> = tensor.expand_shape %<span class="number">3</span> [[<span class="number">0</span>, <span class="number">1</span>], 
[<span class="number">2</span>, <span class="number">3</span>]] : tensor<<span class="number">128</span>x256xf32> into tensor<<span class="number">16</span>x8x64x4xf32></span><br><span class="line"> %<span class="number">7</span> = linalg.init_tensor [<span class="number">16</span>, <span class="number">128</span>, <span class="number">8</span>, <span class="number">2</span>] : tensor<<span class="number">16</span>x128x8x2xf32></span><br><span class="line"> %<span class="number">8</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">4</span> : tensor<<span class="number">16</span>x8x128x2xf32>) <span class="built_in">outs</span>(%<span class="number">7</span> : tensor<<span class="number">16</span>x128x8x2xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32):</span><br><span class="line"> linalg.yield %arg2 : f32</span><br><span class="line"> } -> tensor<<span class="number">16</span>x128x8x2xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.init_tensor [<span class="number">64</span>, <span class="number">128</span>, <span class="number">4</span>, <span class="number">2</span>] : tensor<<span class="number">64</span>x128x4x2xf32></span><br><span class="line"> %<span class="number">10</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d1, d3, d0, d2)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, 
<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">5</span> : tensor<<span class="number">128</span>x2x64x4xf32>) <span class="built_in">outs</span>(%<span class="number">9</span> : tensor<<span class="number">64</span>x128x4x2xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32):</span><br><span class="line"> linalg.yield %arg2 : f32</span><br><span class="line"> } -> tensor<<span class="number">64</span>x128x4x2xf32></span><br><span class="line"> %<span class="number">11</span> = linalg.init_tensor [<span class="number">16</span>, <span class="number">64</span>, <span class="number">8</span>, <span class="number">4</span>] : tensor<<span class="number">16</span>x64x8x4xf32></span><br><span class="line"> %<span class="number">12</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">6</span> : tensor<<span class="number">16</span>x8x64x4xf32>) <span class="built_in">outs</span>(%<span class="number">11</span> : tensor<<span class="number">16</span>x64x8x4xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32):</span><br><span class="line"> linalg.yield %arg2 : f32</span><br><span class="line"> } -> tensor<<span class="number">16</span>x64x8x4xf32></span><br><span class="line"> <span class="comment">// 16 x (128x8x2) @ 64 x (128x4x2) => 16 x 64 x sum_{128}(8x2 * (4x2)^T)</span></span><br><span class="line"> %<span class="number">13</span> = linalg.mmt4d {comment = <span class="string">"generic tiling parameters, as no known kernel was matched for this matmul 
and target"</span>} <span class="built_in">ins</span>(%<span class="number">8</span>, %<span class="number">10</span> : tensor<<span class="number">16</span>x128x8x2xf32>, tensor<<span class="number">64</span>x128x4x2xf32>) <span class="built_in">outs</span>(%<span class="number">12</span> : tensor<<span class="number">16</span>x64x8x4xf32>) -> tensor<<span class="number">16</span>x64x8x4xf32></span><br><span class="line"> %<span class="number">14</span> = linalg.init_tensor [<span class="number">16</span>, <span class="number">8</span>, <span class="number">64</span>, <span class="number">4</span>] : tensor<<span class="number">16</span>x8x64x4xf32></span><br><span class="line"> %<span class="number">15</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">13</span> : tensor<<span class="number">16</span>x64x8x4xf32>) <span class="built_in">outs</span>(%<span class="number">14</span> : tensor<<span class="number">16</span>x8x64x4xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32):</span><br><span class="line"> linalg.yield %arg2 : f32</span><br><span class="line"> } -> tensor<<span class="number">16</span>x8x64x4xf32></span><br><span class="line"> %<span class="number">16</span> = tensor.collapse_shape %<span class="number">15</span> [[<span class="number">0</span>, <span class="number">1</span>], [<span class="number">2</span>, <span class="number">3</span>]] : tensor<<span class="number">16</span>x8x64x4xf32> into tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">17</span> = hal.tensor.<span 
class="keyword">export</span> %<span class="number">16</span> : tensor<<span class="number">128</span>x256xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">17</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Flow::createPadLinalgOpsToIntegerMultiplePass</p><p>Pads the M, N, and K dimensions of a matmul up to an integer multiple of paddingSize, which defaults to 4.</p></li><li><p>mlir::createLinalgNamedOpConversionPass</p><p>Converts <code>linalg.depthwise_conv_2d_nhwc_hwcm</code> with depth_multiplier=1 into <code>linalg.depthwise_conv_2d_nhwc_hwc</code>, and <code>linalg.depthwise_conv_2d_nhwc_hwcm_q</code> with depth_multiplier=1 into <code>linalg.depthwise_conv_2d_nhwc_hwc_q</code>.</p><p>For the role of depth_multiplier, see https://www.tensorflow.org/api_docs/python/tf/keras/layers/DepthwiseConv2D.</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">The number of depthwise convolution output channels for each input channel. The total number of depthwise convolution output channels will be equal to filters_in * depth_multiplier.</span><br></pre></td></tr></table></figure></li><li><p>IREE::Flow::createExpandTensorShapesPass</p><p>Expands a dynamic tensor into the dual form of a tensor plus its dynamic dimensions; one benefit of doing this is that the dynamic dimensions can then participate directly in computation and shape inference. For example,</p>
<figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// func.func private @add(%arg0 : tensor<?x2xf32>, %arg1 : tensor<?x2xf32>) -> tensor<?x2xf32></span></span><br><span class="line"><span class="comment">// iree_input.global private mutable @param : tensor<?x2xf32></span></span><br><span class="line"><span class="comment">// func.func @run(%arg0 : tensor<?x2xf32>) -> tensor<?x2xf32> {</span></span><br><span class="line"><span class="comment">// %0 = iree_input.global.load @param : tensor<?x2xf32></span></span><br><span class="line"><span class="comment">// %1 = call @add(%0, %arg0) : (tensor<?x2xf32>, tensor<?x2xf32>) -> tensor<?x2xf32></span></span><br><span class="line"><span class="comment">// iree_input.global.store %1, @param : tensor<?x2xf32></span></span><br><span class="line"><span class="comment">// return %1 : tensor<?x2xf32></span></span><br><span class="line"><span class="comment">// }</span></span><br><span class="line">func.func <span
class="keyword">private</span> @<span class="built_in">add</span>(!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub}</span><br><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param : tensor<?x2xf32></span><br><span class="line">func.func @<span class="built_in">run</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %param = util.global.load @param : tensor<?x2xf32></span><br><span class="line"> %dim = tensor.dim %param, %c0 : tensor<?x2xf32></span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">export</span> %param : tensor<?x2xf32>{%dim} -> !hal.buffer_view</span><br><span class="line"> %<span class="number">1</span> = call @<span class="built_in">add</span>(%<span class="number">0</span>, %arg0) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span class="number">2</span> = hal.buffer_view.dim<%<span class="number">1</span> : !hal.buffer_view>[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">1</span> : !hal.buffer_view -> tensor<?x2xf32>{%<span class="number">2</span>}</span><br><span class="line"> util.global.store %<span class="number">3</span>, @param : tensor<?x2xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">1</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span
class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub}</span><br><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param : tensor<?x2xf32></span><br><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param__d0 : index</span><br><span class="line">func.func @<span class="built_in">run</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %param = util.global.load @param : tensor<?x2xf32></span><br><span class="line"> %param__d0 = util.global.load @param__d0 : index</span><br><span class="line"> %<span class="number">0</span> = flow.tensor.tie_shape %param : tensor<?x2xf32>{%param__d0}</span><br><span class="line"> %dim = tensor.dim %<span class="number">0</span>, %c0 : tensor<?x2xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">0</span> : tensor<?x2xf32>{%dim} -> !hal.buffer_view</span><br><span class="line"> %<span class="number">2</span> = call @<span class="built_in">add</span>(%<span class="number">1</span>, %arg0) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span class="number">3</span> = hal.buffer_view.dim<%<span class="number">2</span> : !hal.buffer_view>[<span 
class="number">0</span>] : index</span><br><span class="line"> %<span class="number">4</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">2</span> : !hal.buffer_view -> tensor<?x2xf32>{%<span class="number">3</span>}</span><br><span class="line"> util.global.store %<span class="number">4</span>, @param : tensor<?x2xf32></span><br><span class="line"> util.global.store %<span class="number">3</span>, @param__d0 : index</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>Several changes stand out here:</p><ul><li><p>The global tensor gains an extra global index that represents its dynamic dimension.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param : tensor<?x2xf32></span><br><span class="line"></span><br><span class="line">is converted into:</span><br><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param : tensor<?x2xf32></span><br><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param__d0 : index</span><br></pre></td></tr></table></figure></li><li><p>global load</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">%param = util.global.load @param : tensor<?x2xf32></span><br><span class="line"></span><br><span class="line">is converted into:</span><br><span class="line">%param = util.global.load @param : tensor<?x2xf32></span><br><span class="line">%param__d0 = util.global.load @param__d0 : index</span><br><span class="line">%<span class="number">0</span> = flow.tensor.tie_shape %param : tensor<?x2xf32>{%param__d0}</span><br></pre></td></tr></table></figure></li><li><p>global store</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">util.global.store %<span class="number">3</span>, @param : tensor<?x2xf32></span><br><span class="line"></span><br><span class="line">is converted into:</span><br><span class="line">util.global.store %<span class="number">4</span>, @param : tensor<?x2xf32></span><br><span class="line">util.global.store %<span class="number">3</span>, @param__d0 : index</span><br></pre></td></tr></table></figure></li></ul></li><li><p>buildGlobalOptimizationPassPipeline</p><ul><li><p>IREE::Util::createSimplifyGlobalAccessesPass</p><p>This pass mainly does the following:</p><ul><li><p>Hoists loads of immutable global tensors to the beginning of the block, and safely sinks stores of global tensors to the end of the block.</p></li><li><p>Performs the following simplifications:</p><ul><li><p>For a load after a store, the load is replaced directly with the store's source. For example,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">store %<span class="number">0</span>, @p</span><br><span class="line">%<span class="number">1</span> = load @p</span><br><span class="line"><span class="keyword">return</span> %<span class="number">1</span></span><br></pre></td></tr></table></figure><p>is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">store %<span class="number">0</span>, @p</span><br><span class="line"><span
class="keyword">return</span> %<span class="number">0</span></span><br></pre></td></tr></table></figure></li><li><p>For a store after a store, the earlier store is eliminated directly.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">store %<span class="number">0</span>, @p</span><br><span class="line">store %<span class="number">1</span>, @p</span><br></pre></td></tr></table></figure><p>is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">store %<span class="number">1</span>, @p</span><br></pre></td></tr></table></figure></li><li><p>For a load after a load, the later load is eliminated.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">0</span> = load @p</span><br><span class="line">%<span class="number">1</span> = load @p</span><br><span class="line"><span class="keyword">return</span> %<span class="number">1</span></span><br></pre></td></tr></table></figure><p>is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">%<span class="number">0</span> = load @p</span><br><span class="line"><span class="keyword">return</span> %<span class="number">0</span></span><br></pre></td></tr></table></figure></li></ul></li></ul><p>A complete example:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span
class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub}</span><br><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line">util.global <span class="keyword">private</span> @param1 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line">func.func @<span class="built_in">run</span>(%arg0: !hal.buffer_view) attributes {iree.abi.stub} {</span><br><span class="line"> %param0 = util.global.load @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">export</span> %param0 : tensor<<span class="number">1</span>x2xf32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">1</span> = call @<span class="built_in">add</span>(%<span class="number">0</span>, %arg0) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">1</span> : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> util.global.store %<span class="number">2</span>, @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %param0_0 = util.global.load @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %param1 = util.global.load @param1 : tensor<<span 
class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">export</span> %param0_0 : tensor<<span class="number">1</span>x2xf32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">4</span> = hal.tensor.<span class="keyword">export</span> %param1 : tensor<<span class="number">1</span>x2xf32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">5</span> = call @<span class="built_in">add</span>(%<span class="number">3</span>, %<span class="number">4</span>) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span class="number">6</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">5</span> : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> util.global.store %<span class="number">6</span>, @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This is converted into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub}</span><br><span class="line"> util.global <span
class="keyword">private</span> <span class="keyword">mutable</span> @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> util.global <span class="keyword">private</span> @param1 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> func.func @<span class="built_in">run</span>(%arg0: !hal.buffer_view) attributes {iree.abi.stub} {</span><br><span class="line"> %param0 = util.global.load @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %param1 = util.global.load @param1 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">export</span> %param0 : tensor<<span class="number">1</span>x2xf32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">1</span> = call @<span class="built_in">add</span>(%<span class="number">0</span>, %arg0) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">1</span> : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">2</span> : tensor<<span class="number">1</span>x2xf32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">4</span> = hal.tensor.<span class="keyword">export</span> %param1 : tensor<<span class="number">1</span>x2xf32> -> !hal.buffer_view</span><br><span class="line"> util.global.store %<span class="number">2</span>, @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">5</span> = call @<span class="built_in">add</span>(%<span class="number">3</span>, %<span class="number">4</span>) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span 
class="number">6</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">5</span> : !hal.buffer_view -> tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> util.global.store %<span class="number">6</span>, @param0 : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br></pre></td></tr></table></figure><p>In this example, the load of param1 is hoisted earlier, and <code>%param0_0 = util.global.load @param0 : tensor<1x2xf32></code> is replaced directly with <code>%2</code>.</p></li><li><p>IREE::Util::createApplyPatternsPass</p><p>Runs the CanonicalizationPatterns defined in the IREE::Util dialect ODS, and simplifies block arguments and branch operands.</p><ul><li><p>Block argument simplification</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">br ^<span class="built_in">bb1</span>(%<span class="number">0</span>, %<span class="number">0</span> : index, index)</span><br><span class="line">^<span class="built_in">bb1</span>(%arg0: index, %arg1: index):</span><br><span class="line"> ...</span><br></pre></td></tr></table></figure><p>Identical arguments are folded, simplifying this to</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">br ^<span class="built_in">bb1</span>(%<span class="number">0</span> : index)</span><br><span class="line">^<span class="built_in">bb1</span>(%arg0: index): <span class="comment">// %arg1 remapped to %arg0</span></span><br><span class="line"> ...</span><br></pre></td></tr></table></figure></li><li><p>Branch operand elimination</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td
class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: index) {</span><br><span class="line"> br ^<span class="built_in">bb1</span>(%arg0 : index)</span><br><span class="line"> ^<span class="built_in">bb1</span>(%<span class="number">0</span>: index):</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>After the argument is eliminated,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: index) {</span><br><span class="line"> br ^bb1</span><br><span class="line"> ^bb1: <span class="comment">// %0 remapped to %arg0</span></span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul></li><li><p>IREE::Util::createFoldGlobalsPass</p><p>This pass further optimizes loads and stores of global tensors, mainly including:</p><ul><li><p>Inlining constant stores, for example</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">mutable</span> @a : i32</span><br><span class="line">func.func @fool {</span><br><span class="line"> %c5 = arith.constant <span class="number">5</span> : i32</span><br><span class="line"> util.global.store %c5, @a : i32</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">util.global @a = <span 
class="number">5</span> : i32</span><br></pre></td></tr></table></figure></li><li><p>Inlining constant loads, for example</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">util.global @a = <span class="number">5</span> : i32</span><br><span class="line">func.func @fool {</span><br><span class="line"> %<span class="number">1</span> = util.global.load @a : i32</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">func.func @fool {</span><br><span class="line"> %<span class="number">1</span> = arith.constant <span class="number">5</span> : i32</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>Renaming global tensors that form chains with each other.</p></li><li><p>If a mutable global tensor is only stored to in the init function, changing it to immutable.</p></li><li><p>Deleting global tensors that are never loaded.</p></li><li><p>Merging immutable global tensors that have the same initial value.</p></li></ul></li><li><p>IREE::Util::createHoistIntoGlobalsPass</p></li></ul></li><li><p>IREE::Flow::createTensorPadToTensorInsertSlicePass</p><p>Converts <code>tensor.pad</code> into <code>linalg.fill</code> + <code>tensor.insert_slice</code>.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td 
class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x1xf32></span><br><span class="line"> %padded = tensor.pad %<span class="number">0</span> low[<span class="number">1</span>, <span class="number">2</span>] high[<span class="number">3</span>, <span class="number">4</span>] {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg1: index, %arg2: index):</span><br><span class="line"> tensor.yield %cst : f32</span><br><span class="line"> } : tensor<<span class="number">1</span>x1xf32> to tensor<<span class="number">5</span>x7xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">export</span> %padded : tensor<<span class="number">5</span>x7xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">1</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = 
hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x1xf32></span><br><span class="line"> %<span class="number">1</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">5</span>x7xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">1</span> : tensor<<span class="number">5</span>x7xf32>) -> tensor<<span class="number">5</span>x7xf32></span><br><span class="line"> %inserted_slice = tensor.insert_slice %<span class="number">0</span> into %<span class="number">2</span>[<span class="number">1</span>, <span class="number">2</span>] [<span class="number">1</span>, <span class="number">1</span>] [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">1</span>x1xf32> into tensor<<span class="number">5</span>x7xf32></span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">export</span> %inserted_slice : tensor<<span class="number">5</span>x7xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">3</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createConvertElementwiseToLinalgPass</p><p>Converts elementwise ops (ops with the Elementwise trait) into linalg generic ops, making later elementwise op fusion easier. Ops in the arith dialect and math dialect are all Elementwise, so in practice this pass lowers the arith dialect and math dialect to the linalg dialect.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view 
attributes {iree.abi.stub} {</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x3xf32></span><br><span class="line"> %<span class="number">1</span> = arith.addf %<span class="number">0</span>, %<span class="number">0</span> : tensor<<span class="number">2</span>x3xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">1</span> : tensor<<span class="number">2</span>x3xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x3xf32></span><br><span class="line"> %<span class="number">1</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span 
class="number">0</span> : tensor<<span class="number">2</span>x3xf32>, tensor<<span class="number">2</span>x3xf32>) <span class="built_in">outs</span>(%<span class="number">0</span> : tensor<<span class="number">2</span>x3xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">3</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">3</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>x3xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">1</span> : tensor<<span class="number">2</span>x3xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createLinalgFoldUnitExtentDimsPass</p><p>Eliminates dimensions or loops of size 1.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x3xf32></span><br><span class="line"> %<span class="number">1</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> 
(d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span> : tensor<<span class="number">1</span>x3xf32>) <span class="built_in">outs</span>(%<span class="number">0</span> : tensor<<span class="number">1</span>x3xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %out: f32):</span><br><span class="line"> %<span class="number">3</span> = arith.addf %in, %in : f32</span><br><span class="line"> linalg.yield %<span class="number">3</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>x3xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">1</span> : tensor<<span class="number">1</span>x3xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">1</span>x3xf32></span><br><span class="line"> %collapsed = tensor.collapse_shape %<span 
class="number">0</span> [[<span class="number">0</span>, <span class="number">1</span>]] : tensor<<span class="number">1</span>x3xf32> into tensor<<span class="number">3</span>xf32></span><br><span class="line"> %collapsed_0 = tensor.collapse_shape %<span class="number">0</span> [[<span class="number">0</span>, <span class="number">1</span>]] : tensor<<span class="number">1</span>x3xf32> into tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">1</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%collapsed : tensor<<span class="number">3</span>xf32>) <span class="built_in">outs</span>(%collapsed_0 : tensor<<span class="number">3</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %out: f32):</span><br><span class="line"> %<span class="number">3</span> = arith.addf %in, %in : f32</span><br><span class="line"> linalg.yield %<span class="number">3</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">3</span>xf32></span><br><span class="line"> %expanded = tensor.expand_shape %<span class="number">1</span> [[<span class="number">0</span>, <span class="number">1</span>]] : tensor<<span class="number">3</span>xf32> into tensor<<span class="number">1</span>x3xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">export</span> %expanded : tensor<<span class="number">1</span>x3xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure><p>Note that the <code>linalg.generic</code> has been reduced from a 2-level loop nest to a single loop.</p></li><li><p>createInterchangeGenericOpsPass</p><p>Loop dimension interchange. Moves the reduction loop dimensions to the innermost position, with the corresponding parallel loop dimensions moved outward.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// sum(%arg0: tensor<2x3xf32>, 0) -> tensor<3xf32></span></span><br><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x3xf32></span><br><span class="line"> %<span class="number">1</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">1</span> : tensor<<span class="number">3</span>xf32>) -> tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d1)>], iterator_types = [<span class="string">"reduction"</span>, <span 
class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span> : tensor<<span class="number">2</span>x3xf32>) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">3</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %out: f32):</span><br><span class="line"> %<span class="number">5</span> = arith.addf %in, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">5</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">4</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">3</span> : tensor<<span class="number">3</span>xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">4</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>After interchanging the loops, it is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x3xf32></span><br><span class="line"> %<span class="number">1</span> = 
tensor.<span class="built_in">empty</span>() : tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">1</span> : tensor<<span class="number">3</span>xf32>) -> tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d1, d0)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span> : tensor<<span class="number">2</span>x3xf32>) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">3</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %out: f32):</span><br><span class="line"> %<span class="number">5</span> = arith.addf %in, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">5</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">3</span>xf32></span><br><span class="line"> %<span class="number">4</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">3</span> : tensor<<span class="number">3</span>xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">4</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>memref::createResolveShapedTypeResultDimsPass</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createCSEPass</p></li><li><p>createFusionOfTensorOpsPass</p><p>Mainly performs elementwise op fusion; it also converts <code>tensor.expand_shape</code> into a <code>linalg generic op</code> to make further fusion possible.</p><p>Conditions for elementwise op fusion:</p><ul><li>Both the producer and the consumer are linalg generic ops, and both have tensor semantics.</li><li>The producer has exactly one use.</li><li>All of the producer's iteration dimensions are parallel, and the consumer's index map must have the same loop nesting depth as the producer.</li><li>The index map of the producer's result must be a permutation, i.e. each element of the result is stored exactly once (the output is pointwise).</li><li>The consumer may contain reduction iteration types, but the input index maps after fusion must cover every iteration dimension, because a missing dimension would leave its loop bound undeterminable.</li></ul><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// reduce(mul(arg0, arg1), 0)</span></span><br><span class="line"><span class="comment">// for (int d0 = 0; d0 < n; ++d0) {</span></span><br><span class="line"><span class="comment">// temp[d0] = arg0[d0] * arg1[d0];</span></span><br><span class="line"><span class="comment">// }</span></span><br><span class="line"><span class="comment">// result = 0;</span></span><br><span class="line"><span class="comment">// for (int d0 = 0; d0 < n; ++d0) {</span></span><br><span class="line"><span class="comment">// result += temp[d0];</span></span><br><span class="line"><span class="comment">// }</span></span><br><span class="line">func.func @<span class="built_in">foo</span>(%arg0: 
!hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">2</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">2</span>xf32>, tensor<<span class="number">2</span>xf32>) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">2</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">8</span> = arith.mulf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">4</span> = tensor.<span class="built_in">empty</span>() : tensor<f32></span><br><span class="line"> %<span class="number">5</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">4</span> : tensor<f32>) -> 
tensor<f32></span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> ()>], iterator_types = [<span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">3</span> : tensor<<span class="number">2</span>xf32>) <span class="built_in">outs</span>(%<span class="number">5</span> : tensor<f32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %out: f32):</span><br><span class="line"> %<span class="number">8</span> = arith.addf %in, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<f32></span><br><span class="line"> %<span class="number">7</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">6</span> : tensor<f32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">7</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>After fusing the mul and the reduce, it is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// result = 0;</span></span><br><span class="line"><span class="comment">// for (int d0 = 0; d0 
< n; ++d0) {</span></span><br><span class="line"><span class="comment">// result += arg0[d0] * arg1[d0];</span></span><br><span class="line"><span class="comment">// }</span></span><br><span class="line">func.func @<span class="built_in">foo</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">2</span> = tensor.<span class="built_in">empty</span>() : tensor<f32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<f32>) -> tensor<f32></span><br><span class="line"> %<span class="number">4</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> ()>], iterator_types = [<span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">2</span>xf32>, tensor<<span class="number">2</span>xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<f32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">6</span> = arith.mulf %in, %in_0 : f32</span><br><span class="line"> %<span class="number">7</span> = 
arith.addf %<span class="number">6</span>, %out : f32</span><br><span class="line"> linalg.yield %<span class="number">7</span> : f32</span><br><span class="line"> } -> tensor<f32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<f32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createLinalgDetensorizePass</p><p>Converts 0-D tensors to their underlying element type.</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createCSEPass</p></li><li><p>createSplitReductionPass</p><p>Splits the single reduction of matmul and topk into two reduction operations (a batch matmul plus an add). Disabled by default; it can be enabled by setting the --iree-flow-split-matmul-reduction option to >=2.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">256</span>x256xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span 
class="number">128</span>, <span class="number">256</span>] : tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">128</span>x256xf32>) -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.matmul <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">128</span>x256xf32>, tensor<<span class="number">256</span>x256xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">128</span>x256xf32>) -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">5</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">4</span> : tensor<<span class="number">128</span>x256xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">5</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>With --iree-flow-split-matmul-reduction=2 it is converted into</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span 
class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">256</span>x256xf32></span><br><span class="line"> %<span class="number">2</span> = linalg.init_tensor [<span class="number">128</span>, <span class="number">256</span>] : tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">128</span>x256xf32>) -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">4</span> = tensor.expand_shape %<span class="number">0</span> [[<span class="number">0</span>], [<span class="number">1</span>, <span class="number">2</span>]] : tensor<<span class="number">128</span>x256xf32> into tensor<<span class="number">128</span>x2x128xf32></span><br><span class="line"> %<span class="number">5</span> = tensor.expand_shape %<span class="number">1</span> [[<span class="number">0</span>, <span class="number">1</span>], [<span class="number">2</span>]] : tensor<<span class="number">256</span>x256xf32> into tensor<<span class="number">2</span>x128x256xf32></span><br><span class="line"> %<span class="number">6</span> = linalg.init_tensor 
[<span class="number">2</span>, <span class="number">128</span>, <span class="number">256</span>] : tensor<<span class="number">2</span>x128x256xf32></span><br><span class="line"> %<span class="number">7</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">6</span> : tensor<<span class="number">2</span>x128x256xf32>) -> tensor<<span class="number">2</span>x128x256xf32></span><br><span class="line"> %<span class="number">8</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d1, d0, d3)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d3, d2)>, <span class="built_in">affine_map</span><(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>, <span class="string">"reduction"</span>]} <span class="built_in">ins</span>(%<span class="number">4</span>, %<span class="number">5</span> : tensor<<span class="number">128</span>x2x128xf32>, tensor<<span class="number">2</span>x128x256xf32>) <span class="built_in">outs</span>(%<span class="number">7</span> : tensor<<span class="number">2</span>x128x256xf32>) attrs = {__internal_linalg_transform__ = <span class="string">"SPLIT"</span>, linalg.memoized_indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2) -> (d0, d2)>, <span class="built_in">affine_map</span><(d0, d1, d2) -> (d2, d1)>, <span class="built_in">affine_map</span><(d0, d1, d2) -> (d0, d1)>]} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32, %arg4: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.mulf %arg2, %arg3 : f32</span><br><span class="line"> %<span class="number">12</span> = arith.addf %arg4, %<span class="number">11</span> : f32</span><br><span class="line"> linalg.yield %<span class="number">12</span> : f32</span><br><span 
class="line"> } -> tensor<<span class="number">2</span>x128x256xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1, d2) -> (d0, d1, d2)>, <span class="built_in">affine_map</span><(d0, d1, d2) -> (d1, d2)>], iterator_types = [<span class="string">"reduction"</span>, <span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">8</span> : tensor<<span class="number">2</span>x128x256xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">128</span>x256xf32>) attrs = {__internal_linalg_transform__ = <span class="string">"SPLIT"</span>} {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %arg2, %arg3 : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">128</span>x256xf32></span><br><span class="line"> %<span class="number">10</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">9</span> : tensor<<span class="number">128</span>x256xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">10</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createInterchangeGenericOpsPass</p><p>Interchanges loop dimensions: reduction loop dimensions are swapped to the innermost positions, while the corresponding parallel loop dimensions are swapped outward.</p></li><li><p>createInterchangeTransposeGenericOpsPass</p><p>When an input's indexing map is a permutation, interchanges the loop dimensions so that the input's indexing map becomes the identity, which makes the input's memory accesses as contiguous as possible.</p></li><li><p>createDispatchWithTransformDialect</p><p>Schedules and dispatches ops according to the transform dialect. This requires loading a separate module file containing the transform dialect, and the transformation is not applied by default. The transform dialect defines a set of scheduling rules used to guide transformations of the target IR, such as loop unrolling and tiling.</p></li><li><p>createFormDispatchRegionsPass</p><p>Taking a linalg op (or named linalg op) containing a reduction loop as the center (root), merges its producers and consumers according to certain rules to carve out dispatch region subgraphs. A dispatch region is the atomic execution unit in IREE: within a dispatch region the input and output memory can be reused directly, avoiding internal memory allocation; memory allocation happens only at dispatch region boundaries, and synchronization is inserted automatically between dispatch regions.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">10</span>x5xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">import</span> %arg2 : !hal.buffer_view -> tensor<<span class="number">5</span>xf32></span><br><span class="line"> %<span class="number">3</span> = tensor.<span class="built_in">empty</span>() : tensor<<span 
class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">2</span>x5xf32>) -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">5</span> = linalg.matmul <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">2</span>x10xf32>, tensor<<span class="number">10</span>x5xf32>) <span class="built_in">outs</span>(%<span class="number">4</span> : tensor<<span class="number">2</span>x5xf32>) -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">6</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">5</span>, %<span class="number">2</span> : tensor<<span class="number">2</span>x5xf32>, tensor<<span class="number">5</span>xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">2</span>x5xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">8</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">7</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">6</span> : tensor<<span class="number">2</span>x5xf32> -> !hal.buffer_view</span><br><span 
class="line"> <span class="keyword">return</span> %<span class="number">7</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This is transformed into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">10</span>x5xf32></span><br><span class="line"> %<span class="number">2</span> = 
hal.tensor.<span class="keyword">import</span> %arg2 : !hal.buffer_view -> tensor<<span class="number">5</span>xf32></span><br><span class="line"> %<span class="number">3</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">4</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">2</span>x5xf32>) -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> %c1_0 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %<span class="number">5</span> = affine.apply <span class="built_in">affine_map</span><()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0, %c2, %c1_0]</span><br><span class="line"> %c0_1 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c5 = arith.constant <span class="number">5</span> : index</span><br><span class="line"> %c1_2 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %<span class="number">6</span> = affine.apply <span class="built_in">affine_map</span><()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0_1, %c5, %c1_2]</span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.region[%<span class="number">5</span>, %<span class="number">6</span>] -> (tensor<<span class="number">2</span>x5xf32>) {</span><br><span class="line"> %<span class="number">9</span> = linalg.matmul <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">2</span>x10xf32>, tensor<<span 
class="number">10</span>x5xf32>) <span class="built_in">outs</span>(%<span class="number">4</span> : tensor<<span class="number">2</span>x5xf32>) -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">10</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">9</span>, %<span class="number">2</span> : tensor<<span class="number">2</span>x5xf32>, tensor<<span class="number">5</span>xf32>) <span class="built_in">outs</span>(%<span class="number">3</span> : tensor<<span class="number">2</span>x5xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_3: f32, %out: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %in, %in_3 : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> flow.<span class="keyword">return</span> %<span class="number">10</span> : tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> } <span class="built_in">count</span>(%arg3: index, %arg4: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> %<span class="number">8</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">7</span> : tensor<<span class="number">2</span>x5xf32> -> !hal.buffer_view</span><br><span class="line"> <span 
class="keyword">return</span> %<span class="number">8</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createFormDispatchWorkgroupsPass</p><p>Converts dispatch regions into the dispatch workgroup form and copies cloneable ops (such as <code>tensor.fill</code> and <code>tensor.empty</code>) into the workgroup. If tiling was done at the linalg level, this pass also converts the <code>tensor.extract_slice</code> and <code>tensor.insert_slice</code> ops introduced by tiling into <code>flow.tensor.slice</code> and <code>flow.tensor.update</code> where possible; whatever cannot be converted here is later lowered to <code>flow.dispatch.tensor.load</code> and <code>flow.dispatch.tensor.store</code>. Here the result of the previous step is transformed into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">predict</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> %c5 = arith.constant <span class="number">5</span> : 
index</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">10</span>x5xf32></span><br><span class="line"> %<span class="number">2</span> = hal.tensor.<span class="keyword">import</span> %arg2 : !hal.buffer_view -> tensor<<span class="number">5</span>xf32></span><br><span class="line"> %<span class="number">3</span> = flow.dispatch.workgroups[%c2, %c5](%<span class="number">0</span>, %<span class="number">1</span>, %<span class="number">2</span>) : (tensor<<span class="number">2</span>x10xf32>, tensor<<span class="number">10</span>x5xf32>, tensor<<span class="number">5</span>xf32>) -> tensor<<span class="number">2</span>x5xf32> =</span><br><span class="line"> (%arg3: !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>x10xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>x5xf32>>, %arg5: !flow.dispatch.tensor<readonly:tensor<<span class="number">5</span>xf32>>, %arg6: !flow.dispatch.tensor<writeonly:tensor<<span class="number">2</span>x5xf32>>) {</span><br><span class="line"> %cst = arith.constant <span class="number">0.000000e+00</span> : f32</span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.tensor.load %arg3, offsets = [<span class="number">0</span>, <span class="number">0</span>], sizes = [<span class="number">2</span>, <span class="number">10</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>x10xf32>> -> tensor<<span class="number">2</span>x10xf32></span><br><span class="line"> %<span class="number">6</span> = flow.dispatch.tensor.load %arg4, offsets = [<span class="number">0</span>, <span 
class="number">0</span>], sizes = [<span class="number">10</span>, <span class="number">5</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">10</span>x5xf32>> -> tensor<<span class="number">10</span>x5xf32></span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %arg5, offsets = [<span class="number">0</span>], sizes = [<span class="number">5</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">5</span>xf32>> -> tensor<<span class="number">5</span>xf32></span><br><span class="line"> %<span class="number">8</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">9</span> = linalg.fill <span class="built_in">ins</span>(%cst : f32) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<<span class="number">2</span>x5xf32>) -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">10</span> = linalg.matmul <span class="built_in">ins</span>(%<span class="number">5</span>, %<span class="number">6</span> : tensor<<span class="number">2</span>x10xf32>, tensor<<span class="number">10</span>x5xf32>) <span class="built_in">outs</span>(%<span class="number">9</span> : tensor<<span class="number">2</span>x5xf32>) -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> %<span class="number">11</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d1)>, <span class="built_in">affine_map</span><(d0, d1) -> (d0, d1)>], iterator_types = [<span class="string">"parallel"</span>, <span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">10</span>, %<span class="number">7</span> : 
tensor<<span class="number">2</span>x5xf32>, tensor<<span class="number">5</span>xf32>) <span class="built_in">outs</span>(%<span class="number">8</span> : tensor<<span class="number">2</span>x5xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">12</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">12</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>x5xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">11</span>, %arg6, offsets = [<span class="number">0</span>, <span class="number">0</span>], sizes = [<span class="number">2</span>, <span class="number">5</span>], strides = [<span class="number">1</span>, <span class="number">1</span>] : tensor<<span class="number">2</span>x5xf32> -> !flow.dispatch.tensor<writeonly:tensor<<span class="number">2</span>x5xf32>></span><br><span class="line"> flow.<span class="keyword">return</span></span><br><span class="line"> } <span class="built_in">count</span>(%arg3: index, %arg4: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> %<span class="number">4</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">3</span> : tensor<<span class="number">2</span>x5xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">4</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createCaptureDispatchDynamicDimsPass</p><p>Since dynamically shaped tensors in the arguments of <code>flow.dispatch.workgroups</code> have been replaced by <code>!flow.dispatch.tensor</code> values plus the corresponding dynamic-dimension index values, this pass captures those dynamic-dimension indices among the workgroup arguments and inserts <code>flow.dispatch.tie_shape</code> ops to bind them to the <code>!flow.dispatch.tensor</code> arguments.</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// func.func @test(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {</span></span><br><span class="line"><span class="comment">// %0 = mhlo.add %arg0, %arg1 : tensor<?xf32></span></span><br><span class="line"><span class="comment">// return %0 : tensor<?xf32></span></span><br><span class="line"><span class="comment">// }</span></span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} 
{</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<?xf32>{%<span class="number">0</span>}</span><br><span class="line"> %<span class="number">2</span> = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<?xf32>{%<span class="number">2</span>}</span><br><span class="line"> %<span class="number">4</span> = affine.apply <span class="built_in">affine_map</span><()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0, %<span class="number">0</span>, %c1]</span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.workgroups[%<span class="number">4</span>](%<span class="number">0</span>, %<span class="number">1</span>, %<span class="number">3</span>, %<span class="number">0</span>, %<span class="number">2</span>, %<span class="number">0</span>) : (index, tensor<?xf32>{%<span class="number">0</span>}, tensor<?xf32>{%<span class="number">2</span>}, index, index, index) -> tensor<?xf32>{%<span class="number">0</span>} =</span><br><span class="line"> (%arg2: index, %arg3: !flow.dispatch.tensor<readonly:tensor<?xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<?xf32>>, %arg5: index, %arg6: index, %arg7: index, %arg8: !flow.dispatch.tensor<writeonly:tensor<?xf32>>) {</span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tensor.load %arg3, offsets = [<span class="number">0</span>], sizes = [%arg7], strides = [<span class="number">1</span>] : 
!flow.dispatch.tensor<readonly:tensor<?xf32>>{%arg7} -> tensor<?xf32></span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tensor.load %arg4, offsets = [<span class="number">0</span>], sizes = [%arg6], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%arg6} -> tensor<?xf32></span><br><span class="line"> %<span class="number">9</span> = tensor.<span class="built_in">empty</span>(%arg7) : tensor<?xf32></span><br><span class="line"> %<span class="number">10</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">7</span>, %<span class="number">8</span> : tensor<?xf32>, tensor<?xf32>) <span class="built_in">outs</span>(%<span class="number">9</span> : tensor<?xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">11</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">11</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">10</span>, %arg8, offsets = [<span class="number">0</span>], sizes = [%arg7], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%arg7}</span><br><span class="line"> flow.<span class="keyword">return</span></span><br><span class="line"> } <span class="built_in">count</span>(%arg2: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg2</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span 
class="line"> }</span><br><span class="line"> %<span class="number">6</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">5</span> : tensor<?xf32>{%<span class="number">0</span>} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">6</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>This is transformed into:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[<span class="number">0</span>] : index</span><br><span class="line"> 
%<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<?xf32>{%<span class="number">0</span>}</span><br><span class="line"> %<span class="number">2</span> = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[<span class="number">0</span>] : index</span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<?xf32>{%<span class="number">2</span>}</span><br><span class="line"> %<span class="number">4</span> = affine.apply <span class="built_in">affine_map</span><()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0, %<span class="number">0</span>, %c1]</span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.workgroups[%<span class="number">4</span>](%<span class="number">0</span>, %<span class="number">1</span>, %<span class="number">3</span>, %<span class="number">0</span>, %<span class="number">2</span>, %<span class="number">0</span>) : (index, tensor<?xf32>{%<span class="number">0</span>}, tensor<?xf32>{%<span class="number">2</span>}, index, index, index) -> tensor<?xf32>{%<span class="number">0</span>} =</span><br><span class="line"> (%arg2: index, %arg3: !flow.dispatch.tensor<readonly:tensor<?xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<?xf32>>, %arg5: index, %arg6: index, %arg7: index, %arg8: !flow.dispatch.tensor<writeonly:tensor<?xf32>>) {</span><br><span class="line"> %<span class="number">7</span> = flow.dispatch.tie_shape %arg3 : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%arg7}</span><br><span class="line"> %<span class="number">8</span> = flow.dispatch.tie_shape %arg4 : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%arg6}</span><br><span class="line"> %<span class="number">9</span> = flow.dispatch.tie_shape %arg8 : !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%arg7}</span><br><span class="line"> %<span class="number">10</span> = flow.dispatch.tensor.load %<span class="number">7</span>, 
offsets = [<span class="number">0</span>], sizes = [%arg7], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%arg7} -> tensor<?xf32></span><br><span class="line"> %<span class="number">11</span> = flow.dispatch.tensor.load %<span class="number">8</span>, offsets = [<span class="number">0</span>], sizes = [%arg6], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%arg6} -> tensor<?xf32></span><br><span class="line"> %<span class="number">12</span> = tensor.<span class="built_in">empty</span>(%arg7) : tensor<?xf32></span><br><span class="line"> %<span class="number">13</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">10</span>, %<span class="number">11</span> : tensor<?xf32>, tensor<?xf32>) <span class="built_in">outs</span>(%<span class="number">12</span> : tensor<?xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">14</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">14</span> : f32</span><br><span class="line"> } -> tensor<?xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">13</span>, %<span class="number">9</span>, offsets = [<span class="number">0</span>], sizes = [%arg7], strides = [<span class="number">1</span>] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%arg7}</span><br><span class="line"> flow.<span class="keyword">return</span></span><br><span class="line"> } <span class="built_in">count</span>(%arg2: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = 
flow.dispatch.workgroup_count_from_dag_root %arg2</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> %<span class="number">6</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">5</span> : tensor<?xf32>{%<span class="number">0</span>} -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">6</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>createCSEPass</p></li><li><p>createInitializeEmptyTensorsPass</p><p>如果<code>tensor.empty</code> op的user中存在非linalg或IREE LinalgExtop,则把该<code>tensor.empty</code>op转换成<code>flow.tensor.empty</code>或<code>flow.tensor.splat</code>op。</p></li><li><p>IREE::Flow::createOutlineDispatchRegionsPass</p><p>把每个dispatch region转换成<code>flow.executable</code> +<code>flow.dispatch</code> op。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes 
{iree.abi.stub} {</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">2</span> = flow.dispatch.workgroups[%c2](%<span class="number">0</span>, %<span class="number">1</span>) : (tensor<<span class="number">2</span>xf32>, tensor<<span class="number">2</span>xf32>) -> tensor<<span class="number">2</span>xf32> =</span><br><span class="line"> (%arg2: !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>>, %arg3: !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>>, %arg4: !flow.dispatch.tensor<writeonly:tensor<<span class="number">2</span>xf32>>) {</span><br><span class="line"> %<span class="number">4</span> = flow.dispatch.tensor.load %arg2, offsets = [<span class="number">0</span>], sizes = [<span class="number">2</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>> -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">5</span> = flow.dispatch.tensor.load %arg3, offsets = [<span class="number">0</span>], sizes = [<span class="number">2</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>> -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">6</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">7</span> = linalg.generic {indexing_maps = [<span 
class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">4</span>, %<span class="number">5</span> : tensor<<span class="number">2</span>xf32>, tensor<<span class="number">2</span>xf32>) <span class="built_in">outs</span>(%<span class="number">6</span> : tensor<<span class="number">2</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">8</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">8</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span class="number">7</span>, %arg4, offsets = [<span class="number">0</span>], sizes = [<span class="number">2</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">2</span>xf32> -> !flow.dispatch.tensor<writeonly:tensor<<span class="number">2</span>xf32>></span><br><span class="line"> flow.<span class="keyword">return</span></span><br><span class="line"> } <span class="built_in">count</span>(%arg2: index) -> (index, index, index) {</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg2</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line"> }</span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">2</span> : tensor<<span class="number">2</span>xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">3</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line">flow.executable <span class="keyword">private</span> @test_dispatch_0 {</span><br><span class="line"> flow.executable.<span class="keyword">export</span> <span class="keyword">public</span> @<span class="function">test_dispatch_0_generic_2 <span class="title">workgroups</span><span class="params">(%arg0: index)</span> -> <span class="params">(index, index, index)</span> </span>{</span><br><span class="line"> %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0</span><br><span class="line"> flow.<span class="keyword">return</span> %x, %y, %z : index, index, index</span><br><span class="line">}</span><br><span class="line">builtin.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test_dispatch_0_generic_2</span>(%arg0: !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<<span 
class="number">2</span>xf32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<<span class="number">2</span>xf32>>) {</span><br><span class="line"> %<span class="number">0</span> = flow.dispatch.tensor.load %arg0, offsets = [<span class="number">0</span>], sizes = [<span class="number">2</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>> -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">1</span> = flow.dispatch.tensor.load %arg1, offsets = [<span class="number">0</span>], sizes = [<span class="number">2</span>], strides = [<span class="number">1</span>] : !flow.dispatch.tensor<readonly:tensor<<span class="number">2</span>xf32>> -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">2</span> = tensor.<span class="built_in">empty</span>() : tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">3</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">2</span>xf32>, tensor<<span class="number">2</span>xf32>) <span class="built_in">outs</span>(%<span class="number">2</span> : tensor<<span class="number">2</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%in: f32, %in_0: f32, %out: f32):</span><br><span class="line"> %<span class="number">4</span> = arith.addf %in, %in_0 : f32</span><br><span class="line"> linalg.yield %<span class="number">4</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> flow.dispatch.tensor.store %<span 
class="number">3</span>, %arg2, offsets = [<span class="number">0</span>], sizes = [<span class="number">2</span>], strides = [<span class="number">1</span>] : tensor<<span class="number">2</span>xf32> -> !flow.dispatch.tensor<writeonly:tensor<<span class="number">2</span>xf32>></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %c2 = arith.constant <span class="number">2</span> : index</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">2</span> = flow.dispatch @test_dispatch_0::@test_dispatch_0_generic_2[%c2](%<span class="number">0</span>, %<span class="number">1</span>) : (tensor<<span class="number">2</span>xf32>, tensor<<span class="number">2</span>xf32>) -> tensor<<span class="number">2</span>xf32></span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">2</span> : tensor<<span class="number">2</span>xf32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">3</span> : !hal.buffer_view</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Util::createStripDebugOpsPass</p><p>消除DebugOnly 
op。</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>IREE::Flow::createDeduplicateExecutablesPass</p><p>消除重复的<code>flow.executable</code>。</p></li><li><p>IREE::Flow::createInjectDispatchTracingPass</p><p>注入跟踪运行时dispatch函数输入和输出信息的op。默认不开启。</p></li><li><p>IREE::Flow::createCleanupTensorShapesPass</p><p>删除<code>flow.tensor.tie_shape</code>op,并确认module中不再包含<code>tensor.dim</code>和<code>tensor.rank</code>这两类形状查询op。</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createCSEPass</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createCSEPass</p></li><li><p>mlir::createSymbolDCEPass</p></li></ul>]]></content>
    <summary type="html"><p>The main purpose of IREE Flow::buildFlowTransformPassPipeline is to perform a series of peephole optimizations — for example, converting 1x1 conv2d into matmul, tiling, and op fusion — ultimately splitting the workload into <code>flow.executable</code> ops. The relevant passes and their functions are described below.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
<title>IREE编译流程解析(三)</title>
<link href="https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B3/"/>
<id>https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B3/</id>
<published>2023-01-04T12:30:12.000Z</published>
<updated>2023-02-17T11:31:47.491Z</updated>
<content type="html"><![CDATA[<p>IREEABI::TransformPassPipeline主要作用是将外部导入的接口和本module导出到外部的接口参数统一成标准标量类型或<code>hal.buffer_view</code>类型(<code>hal.buffer_view</code>对应tensor),包含以下几个passes。</p><span id="more"></span><ul><li><p>createWrapEntryPointsPass</p><p>给external func生成一个内部函数,函数中调用原始的externalfunc,同时将public func的函数体包装成一个新的函数,原publicfunc中调用该函数。该pass最终的目的是将外部导入的接口和本module导出到外部的接口参数统一成标准标量类型或<code>hal.buffer_view</code>(<code>hal.buffer_view</code>对应tensor类型)。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// external/imported func</span></span><br><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(tensor<f32>, tensor<f32>) -> tensor<f32></span><br><span class="line"></span><br><span class="line"><span class="comment">// public/exported func</span></span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {</span><br><span class="line"> %<span class="number">0</span> = call @<span class="built_in">add</span>(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<f32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span 
class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub}</span><br><span class="line">func.func <span class="keyword">private</span> @_add(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">export</span> %arg0 : tensor<f32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">export</span> %arg1 : tensor<f32> -> !hal.buffer_view</span><br><span class="line"> %<span class="number">2</span> = call @<span class="built_in">add</span>(%<span class="number">0</span>, %<span class="number">1</span>) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">import</span> %<span class="number">2</span> : !hal.buffer_view -> tensor<f32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">3</span> : tensor<f32></span><br><span class="line">}</span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %<span class="number">0</span> = hal.tensor.<span class="keyword">import</span> %arg0 : !hal.buffer_view -> tensor<f32></span><br><span class="line"> %<span class="number">1</span> = hal.tensor.<span class="keyword">import</span> %arg1 : !hal.buffer_view -> 
tensor<f32></span><br><span class="line"> %<span class="number">2</span> = call @_test(%<span class="number">0</span>, %<span class="number">1</span>) : (tensor<f32>, tensor<f32>) -> tensor<f32></span><br><span class="line"> %<span class="number">3</span> = hal.tensor.<span class="keyword">export</span> %<span class="number">2</span> : tensor<f32> -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">3</span> : !hal.buffer_view</span><br><span class="line">}</span><br><span class="line">func.func <span class="keyword">private</span> @_test(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {</span><br><span class="line"> %<span class="number">0</span> = call @_add(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<f32></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createInlinerPass</p><p>将WrapEntryPointsPass中生成的wrap函数内联起来。最终转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub}</span><br><span class="line">func.func @<span class="built_in">test</span>(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {</span><br><span class="line"> %<span class="number">0</span> = call @<span class="built_in">add</span>(%arg0, %arg1) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : !hal.buffer_view</span><br><span 
class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createCSEPass</p></li><li><p>mlir::createSymbolDCEPass</p></li></ul>]]></content>
    <summary type="html"><p>The main purpose of IREE ABI::TransformPassPipeline is to unify the parameters of externally imported interfaces and of the interfaces this module exports into standard scalar types or the <code>hal.buffer_view</code> type (<code>hal.buffer_view</code> corresponds to a tensor). It includes the following passes.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
<title>IREE编译流程解析(二)</title>
<link href="https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B2/"/>
<id>https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B2/</id>
<published>2023-01-04T12:20:12.000Z</published>
<updated>2023-02-17T11:31:40.295Z</updated>
<content type="html"><![CDATA[<p>IREE CommonInputConversionPassPipeline主要作用是将IREE::Input dialectlower成IREE::Util、IREE::Flow和IREE::HALdialect,包括以下几个passes。</p><span id="more"></span><ul><li>createIREEImportPublicPass</li></ul><p>将IREE::Input dialect转换成IREE::Util、IREE::Flow和IREE::HALdialect,并转换func的属性和signature中输入输出类型。比如,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">iree_input.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line">func.func @<span class="built_in">run</span>(%arg0: tensor<<span class="number">1</span>x2xf32>) {</span><br><span class="line"> %<span class="number">0</span> = iree_input.global.load @param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">1</span> = iree_input.tensor.clone %<span class="number">0</span> : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> iree_input.global.store %<span class="number">1</span>, @param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成(<code>iree_input.global.load</code>-><code>util.global.load</code>,<code>iree_input.global.store</code>-><code>util.global.store</code>,<code>iree_input.tensor.clone</code>-><code>flow.tensor.clone</code>):</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span 
class="line">7</span><br></pre></td><td class="code"><pre><span class="line">util.global <span class="keyword">private</span> <span class="keyword">mutable</span> @param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line">func.func @<span class="built_in">run</span>(%arg0: tensor<<span class="number">1</span>x2xf32>) {</span><br><span class="line"> %param = util.global.load @param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> %<span class="number">0</span> = flow.tensor.clone %param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> util.global.store %<span class="number">0</span>, @param : tensor<<span class="number">1</span>x2xf32></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><ul><li>createImportMLProgramPass</li></ul><p>将ml_program dialect转换到IREE::Util dialect。</p><ul><li><p>createSanitizeModuleNamesPass</p><p>将module name中的<code>.</code>替换为<code>_</code>,以符合mliridentifiers的命名规范。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">module</span> @iree.<span class="keyword">module</span> {</span><br><span class="line"> func.func @<span class="built_in">test</span>(%arg0: f32, %arg1: f32) -> f32 {</span><br><span class="line"> %<span class="number">0</span> = arith.addf %arg0, %arg1 : f32</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : f32</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">module</span> @iree_module {</span><br><span class="line"> func.func @<span class="built_in">test</span>(%arg0: f32, %arg1: f32) -> f32 {</span><br><span class="line"> %<span class="number">0</span> = arith.addf %arg0, %arg1 : f32</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : f32</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul>]]></content>
    <summary type="html"><p>The main purpose of IREE CommonInputConversionPassPipeline is to lower the IREE::Input dialect into the IREE::Util, IREE::Flow, and IREE::HAL dialects. It includes the following passes.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
<title>IREE编译流程解析(一)</title>
<link href="https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B1/"/>
<id>https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B1/</id>
<published>2023-01-04T12:14:12.000Z</published>
<updated>2023-02-21T12:50:30.443Z</updated>
<content type="html"><![CDATA[<p>IREEInputConversionPassPipeline的主要作用是将不同的输入(MHLO、XLA、TorchTensor和TOSA)统一lower成linalg dialect和builtin的arith dialect、scfdialect和tensordialect。下面以MHLO输入为例,列举了InputConversionPassPipeline中各个pass以及它们的主要作用。</p><span id="more"></span><ul><li><p>mhlo::createLegalizeControlFlowPass</p><p>将TF1.0中的控制流原语(http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf)规范化成HLO中的控制流算子。</p></li><li><p>createTopLevelSCFToCFGPass</p><p>将顶层的structured controlflow表示的控制流图转换成更底层基础块的控制流图(CFG)。</p></li><li><p>createMHLOToMHLOPreprocessingPass</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createShapeToShapeLowering</p><p>将 <code>shape.num_elements</code> 转换成<code>shape.reduce</code>。</p></li><li><p>mlir::createConvertShapeToStandardPass</p><p>将shape dialect lower成arith dialect、scf dialect和tensordialect。比如</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">1</span>x?xf32>, %arg1: tensor<?xf32>) -> index {</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %<span class="number">0</span> = shape.dim %arg0, %c1 : tensor<<span class="number">1</span>x?xf32>, index -> index</span><br><span class="line"> %<span class="number">1</span> = shape.dim %arg1, %c0 : tensor<?xf32>, index -> index</span><br><span class="line"> %<span class="number">2</span> = shape.add %<span class="number">0</span>, %<span class="number">1</span> : index, index -> index</span><br><span 
class="line"> <span class="keyword">return</span> %<span class="number">2</span> : index</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">1</span>x?xf32>, %arg1: tensor<?xf32>) -> index {</span><br><span class="line"> %c1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c0 = arith.constant <span class="number">0</span> : index</span><br><span class="line"> %c1_0 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %c1_1 = arith.constant <span class="number">1</span> : index</span><br><span class="line"> %<span class="number">0</span> = tensor.dim %arg0, %c1_1 : tensor<<span class="number">1</span>x?xf32></span><br><span class="line"> %<span class="number">1</span> = tensor.from_elements %c1_0, %<span class="number">0</span> : tensor<<span class="number">2</span>xindex></span><br><span class="line"> %<span class="number">2</span> = tensor.cast %<span class="number">1</span> : tensor<<span class="number">2</span>xindex> to tensor<<span class="number">2</span>xindex></span><br><span class="line"> %<span class="number">3</span> = tensor.dim %arg0, %c1 : tensor<<span class="number">1</span>x?xf32></span><br><span class="line"> %c0_2 = arith.constant 
<span class="number">0</span> : index</span><br><span class="line"> %<span class="number">4</span> = tensor.dim %arg1, %c0_2 : tensor<?xf32></span><br><span class="line"> %<span class="number">5</span> = tensor.from_elements %<span class="number">4</span> : tensor<<span class="number">1</span>xindex></span><br><span class="line"> %<span class="number">6</span> = tensor.cast %<span class="number">5</span> : tensor<<span class="number">1</span>xindex> to tensor<<span class="number">1</span>xindex></span><br><span class="line"> %<span class="number">7</span> = tensor.dim %arg1, %c0 : tensor<?xf32></span><br><span class="line"> %<span class="number">8</span> = arith.addi %<span class="number">3</span>, %<span class="number">7</span> : index</span><br><span class="line"> <span class="keyword">return</span> %<span class="number">8</span> : index</span><br><span class="line"> }</span><br></pre></td></tr></table></figure></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createInlinerPass</p><p>内联calls和callable operations,并删除dead callables。比如:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">1</span>xf32>, %arg1: tensor<<span class="number">1</span>xf32>) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = call @<span class="built_in">add</span>(%arg0, %arg1) : (tensor<<span class="number">1</span>xf32>, tensor<<span class="number">1</span>xf32>) -> tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<<span 
class="number">1</span>xf32></span><br><span class="line">}</span><br><span class="line">func.func <span class="keyword">private</span> @<span class="built_in">add</span>(%arg0: tensor<<span class="number">1</span>xf32>, %arg1: tensor<<span class="number">1</span>xf32>) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = mhlo.add %arg0, %arg1 : tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>私有的add函数被内联之后删除,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">1</span>xf32>, %arg1: tensor<<span class="number">1</span>xf32>) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = mhlo.add %arg0, %arg1 : tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>IREE::Util::createDemoteI64ToI32Pass</p></li><li><p>IREE::Util::createDemoteF64ToF32Pass</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>mlir::createCSEPass</p></li><li><p>mhlo::createLegalizeShapeComputationsPass</p><p>把scalar tensor op转换成scalar op + fromElements op。比如</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span 
class="line">6</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: f32, %arg1: f32) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = tensor.from_elements %arg0 : tensor<<span class="number">1</span>xf32></span><br><span class="line"> %<span class="number">1</span> = tensor.from_elements %arg1 : tensor<<span class="number">1</span>xf32></span><br><span class="line"> %<span class="number">2</span> = mhlo.add %<span class="number">0</span>, %<span class="number">1</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">2</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: f32, %arg1: f32) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = arith.addf %arg0, %arg1 : f32</span><br><span class="line"> %<span class="number">1</span> = tensor.from_elements %<span class="number">0</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">1</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createConvertMHLOToLinalgExtPass</p><p>将<code>mhlo::sort</code>、<code>mhlo.scatter</code>、<code>mhlo.fft</code>、<code>mhlo.reverse</code>、<code>mhlo.topk</code>转换到IREE::LinalgExtdialect,同时将在IREE::LinalgExt dialect区域内部的mhlo 
op转换成linalgdialect,<code>mhlo.return</code>则转换成<code>iree_linalg_ext.yield</code>。比如,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">10</span>xf32>) -> tensor<<span class="number">10</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = <span class="string">"mhlo.sort"</span>(%arg0) ({</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg1: tensor<f32>, %arg2: tensor<f32>):</span><br><span class="line"> %<span class="number">1</span> = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1></span><br><span class="line"> mhlo.<span class="keyword">return</span> %<span class="number">1</span> : tensor<i1></span><br><span class="line"> }) {dimension = <span class="number">0</span> : i64} : (tensor<<span class="number">10</span>xf32>) -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<<span class="number">10</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">10</span>xf32>) -> tensor<<span class="number">10</span>xf32> {</span><br><span 
class="line"> %<span class="number">0</span> = iree_linalg_ext.sort <span class="built_in">dimension</span>(<span class="number">0</span>) <span class="built_in">outs</span>(%arg0 : tensor<<span class="number">10</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg1: f32, %arg2: f32):</span><br><span class="line"> %<span class="number">1</span> = arith.cmpf ogt, %arg1, %arg2 : f32</span><br><span class="line"> iree_linalg_ext.yield %<span class="number">1</span> : i1</span><br><span class="line"> } -> tensor<<span class="number">10</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<<span class="number">10</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>createMHLOToLinalgOnTensorsPass</p><p>将外层剩余的mhlo op转换到linalg dialect。比如</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">1</span>xf32>, %arg1: tensor<<span class="number">1</span>xf32>) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = mhlo.add %arg0, %arg1 : tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">0</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>转换成,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span 
class="line">9</span><br></pre></td><td class="code"><pre><span class="line">func.func @<span class="built_in">test</span>(%arg0: tensor<<span class="number">1</span>xf32>, %arg1: tensor<<span class="number">1</span>xf32>) -> tensor<<span class="number">1</span>xf32> {</span><br><span class="line"> %<span class="number">0</span> = linalg.init_tensor [<span class="number">1</span>] : tensor<<span class="number">1</span>xf32></span><br><span class="line"> %<span class="number">1</span> = linalg.generic {indexing_maps = [<span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>, <span class="built_in">affine_map</span><(d0) -> (d0)>], iterator_types = [<span class="string">"parallel"</span>]} <span class="built_in">ins</span>(%arg0, %arg1 : tensor<<span class="number">1</span>xf32>, tensor<<span class="number">1</span>xf32>) <span class="built_in">outs</span>(%<span class="number">0</span> : tensor<<span class="number">1</span>xf32>) {</span><br><span class="line"> ^<span class="built_in">bb0</span>(%arg2: f32, %arg3: f32, %arg4: f32):</span><br><span class="line"> %<span class="number">2</span> = arith.addf %arg2, %arg3 : f32</span><br><span class="line"> linalg.yield %<span class="number">2</span> : f32</span><br><span class="line"> } -> tensor<<span class="number">1</span>xf32></span><br><span class="line"> <span class="keyword">return</span> %<span class="number">1</span> : tensor<<span class="number">1</span>xf32></span><br><span class="line">}</span><br></pre></td></tr></table></figure></li><li><p>mlir::createReconcileUnrealizedCastsPass</p><p>Eliminates unrealized conversion cast operations. The algorithm works as follows: if an unrealized conversion cast is a dead node (it has no users, or all of its users are themselves unrealized conversion casts), the dead node is deleted directly. If it is a live node (it has at least one user that is not an unrealized conversion cast), all of its child nodes are traversed; if the result type of every unrealized conversion cast among them matches this op's input type (i.e., no type cast with real effect takes place), all traversed unrealized conversion casts are folded into this op's inputs, otherwise the error <code>live unrealized conversion cast</code> is reported.</p></li><li><p>mlir::createCanonicalizerPass</p></li><li><p>createVerifyCompilerMHLOInputLegality</p><p>Verifies that the program is legal.</p></li></ul>]]></content>
      <summary type="html"><p>The main purpose of IREE&#39;s
      InputConversionPassPipeline is to uniformly lower the different
      inputs (MHLO, XLA, Torch Tensor, and TOSA) into the linalg dialect
      plus the builtin arith, scf, and tensor dialects. Taking MHLO input
      as an example, the following lists each pass in the
      InputConversionPassPipeline and its main purpose.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
    <title>IREE Compilation Pipeline Explained</title>
<link href="https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B/"/>
<id>https://hjchen2.github.io/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B/</id>
<published>2023-01-04T04:00:04.000Z</published>
<updated>2023-02-24T12:36:50.764Z</updated>
    <content type="html"><![CDATA[<p>IREE currently supports MHLO (or XLA), Torch Tensor, and TOSA as inputs; a series of passes compiles them into IREE-defined VM bytecode as an intermediate artifact. Hardware-specific code is compiled into corresponding Executables stored inside the VM bytecode for the host to invoke; for example, CUDA compute code is lowered to PTX, which the CUDA runtime then JIT-compiles into an executable cubin kernel inside the IREE runtime.</p><span id="more"></span><p>The entry point of IREE compilation is IREEVMTransformPassPipeline, which in turn is divided into several stages: InputConversionPassPipeline, CommonInputConversionPassPipeline, ABI::TransformPassPipeline, Flow::FlowTransformPassPipeline, Stream::StreamTransformPassPipeline (CUDA backend only), HAL::HALTransformPassPipeline, and VM::VMTransformPassPipeline.</p><ul><li><p>InputConversionPassPipeline</p><a href="/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B1/" title="IREE Compilation Pipeline Explained (Part 1)">IREE Compilation Pipeline Explained (Part 1)</a><p>Its main purpose is to uniformly lower the different inputs (MHLO or XLA, Torch Tensor, and TOSA) into the linalg dialect plus the builtin arith, scf, and tensor dialects.</p></li><li><p>CommonInputConversionPassPipeline</p><a href="/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B2/" title="IREE Compilation Pipeline Explained (Part 2)">IREE Compilation Pipeline Explained (Part 2)</a><p>Its main purpose is to lower the IREE::Input dialect into the IREE::Util, IREE::Flow, and IREE::HAL dialects.</p></li><li><p>ABI::TransformPassPipeline</p><a href="/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B3/" title="IREE Compilation Pipeline Explained (Part 3)">IREE Compilation Pipeline Explained (Part 3)</a><p>Its main purpose is to normalize the parameters of externally imported interfaces and of the interfaces this module exports into standard scalar types or the <code>hal.buffer_view</code> type (<code>hal.buffer_view</code> corresponds to a tensor).</p></li><li><p>Flow::FlowTransformPassPipeline</p><a href="/2023/01/04/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B4/" title="IREE Compilation Pipeline Explained (Part 4)">IREE Compilation Pipeline Explained (Part 4)</a><p>Its main purpose is to run a series of peephole optimizations, such as converting 1x1 conv2d into matmul, tiling, op fusion, and so on, finally splitting the workload into <code>flow.executable</code>s.</p></li><li><p>Stream::StreamTransformPassPipeline</p><a href="/2023/02/13/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B5/" title="IREE Compilation Pipeline Explained (Part 5)">IREE Compilation Pipeline Explained (Part 5)</a><p>Its main purpose is to convert the program to the stream dialect, optimize variable encoding, partition scheduling subgraphs, generate an asynchronous scheduling policy, and implement a memory planning policy.</p></li><li><p>HAL::HALTransformPassPipeline</p><a href="/2023/02/24/IREE%E7%BC%96%E8%AF%91%E6%B5%81%E7%A8%8B6/" title="IREE Compilation Pipeline Explained (Part 6)">IREE Compilation Pipeline Explained (Part 6)</a><p>Its main purpose is to perform tiling, vectorization, bufferization, and similar transformations, distribute the computational workload, and finally generate code for the target device. For example, the dispatch source code for the cuda target is lowered to NVVM IR.</p></li><li><p>VM::VMTransformPassPipeline</p></li></ul>]]></content>
      <summary type="html"><p>IREE currently supports MHLO (or XLA),
      Torch Tensor, and TOSA as inputs; a series of passes compiles them
      into IREE-defined VM bytecode as an intermediate artifact.
      Hardware-specific code is compiled into corresponding Executables
      stored inside the VM bytecode for the host to invoke; for example,
      CUDA compute code is lowered to PTX, which the CUDA runtime then
      JIT-compiles into an executable cubin kernel inside the IREE
      runtime.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="IREE" scheme="https://hjchen2.github.io/tags/IREE/"/>
</entry>
<entry>
    <title>How to Add a Custom Backend Engine to the XRT Framework</title>
<link href="https://hjchen2.github.io/2020/02/25/%E5%A6%82%E4%BD%95%E5%9C%A8XRT%E6%A1%86%E6%9E%B6%E4%B8%8B%E6%B7%BB%E5%8A%A0%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84%E5%90%8E%E7%AB%AF%E5%BC%95%E6%93%8E/"/>
<id>https://hjchen2.github.io/2020/02/25/%E5%A6%82%E4%BD%95%E5%9C%A8XRT%E6%A1%86%E6%9E%B6%E4%B8%8B%E6%B7%BB%E5%8A%A0%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84%E5%90%8E%E7%AB%AF%E5%BC%95%E6%93%8E/</id>
<published>2020-02-25T08:06:18.000Z</published>
<updated>2023-02-07T02:39:00.722Z</updated>
<content type="html"><![CDATA[<p>XRT为不同的后端引擎提供了统一的上层功能和接口抽象,这些功能和接口包括:</p><ul><li>统一的DAG计算图表示</li><li>统一的子图表达、切分和折叠过程</li><li>统一的JIT子图编译接口和缓存机制</li><li>统一的Executable Launch接口</li></ul><p>得益于上层统一的抽象和模块化的设计,后端引擎只需要处理一些差异化的接口,并且这些差异化通常只体现在子图的编译和executablelaunch接口的具体实现上。</p><span id="more"></span><p>我们把XRT的每个子图都看成是一个function,function包含输入和输出参数,以及对应的函数体(DAG表示的计算图),比如下面表示的是只包含一个relu节点的XRT子图,其中node表示计算节点,input和output分别表示子图的输入和输出。</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line">function {</span><br><span class="line"> input {</span><br><span class="line"> name: "_xrt_entry_0"</span><br><span class="line"> value: "_MyGraph_0_input.0.0_2/out"</span><br><span class="line"> }</span><br><span class="line"> output {</span><br><span class="line"> name: "_xrt_return_0"</span><br><span class="line"> value: "relu-0/y_0"</span><br><span class="line"> }</span><br><span class="line"> node {</span><br><span class="line"> name: "relu-0"</span><br><span class="line"> device_tag: "cuda"</span><br><span class="line"> 
user_conf {</span><br><span class="line"> op_type_name: "relu"</span><br><span class="line"> input {</span><br><span class="line"> key: "x"</span><br><span class="line"> value {</span><br><span class="line"> s: "_MyGraph_0_input.0.0_2/out"</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> output {</span><br><span class="line"> key: "y"</span><br><span class="line"> value {</span><br><span class="line"> s: "relu-0/y_0"</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>在runtime阶段function首先需要被编译成executable,执行function实际上就是feed对应的输入参数去launch这个编译好的executable,同时得到执行的结果,即function的返回值。</p><p>在XRT框架下每个后端引擎都有一个与之相对应的executable(比如XLA的XlaExecutable和TensorRT的TrtExecutable),和将function编译成对应executable的compiler(比如XLA的XlaGraphCompiler和TensorRT的TrtGraphCompiler),因此添加一个新的后端引擎,通常只需要添加一个对应的executable和compiler。下面以添加一个自定义的后端引擎Toy为例,详细介绍在XRT框架下支持新的后端引擎的具体过程。</p><p>首先在xrt.proto文件中XrtEngine下增加一个Toy引擎字段。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">enum</span> <span class="title class_">XrtEngine</span> {</span><br><span class="line"> DEFAULT = <span class="number">1</span>;</span><br><span class="line"> XLA = <span class="number">2</span>;</span><br><span class="line"> TENSORRT = <span class="number">3</span>;</span><br><span class="line"> TVM = <span class="number">4</span>;</span><br><span class="line"> TOY = <span class="number">5</span>; <span class="comment">// For Toy engine</span></span><br><span 
class="line">}</span><br></pre></td></tr></table></figure><p>如果Toy引擎针对的硬件不在XrtDevice中,则需要在XrtDevice中增加对应的设备字段。这里我们假设自定义的Toy引擎只支持GPU_CUDA,因此就不需要修改XrtDevice了。</p><p>接下来,与XLA和TensorRT一样,我们在<code>oneflow_xrt/compiler</code>目录下创建一个toy目录,其余所有与Toy引擎相关的代码都将放在该目录下。</p><h2 id="toy-executable">Toy Executable</h2><p>在增加任何一个后端引擎之前,我们都需要仔细考虑该后端引擎所需的最小执行环境,一个最简单的执行环境包括输入输出、中间结果以及执行具体计算逻辑的硬件代码,这个代码可以是通过codegen自动生成的,也可以是手工实现的。</p><p>接下来我们给自定义的Toy引擎增加一个对应的ToyExecutable。在<code>oneflow_xrt/compiler/toy</code>目录下,我们创建文件toy_executable.h和toy_executable.cpp。</p><p>toy_executable.h中定义ToyExecutable,ToyExecutable必须继承自Executable,并实现Run接口。为了尽可能简单,ToyExecutable只包含输出outputs、中间结果tmp_buffers和编排好的函数调用列表func_codes,以及每个函数的输入输出参数对应的buffer序号func_args_。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span 
class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">ifndef</span> ONEFLOW_XRT_COMPILER_TOY_TOY_EXECUTABLE_H_</span></span><br><span class="line"></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"oneflow_xrt/compiler/executable.h"</span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"oneflow_xrt/compiler/parameter.h"</span></span></span><br><span class="line"></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><vector></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><functional></span></span></span><br><span class="line"></span><br><span class="line"><span class="keyword">namespace</span> oneflow {</span><br><span class="line"><span class="keyword">namespace</span> xrt {</span><br><span class="line"></span><br><span class="line"><span class="keyword">typedef</span> std::function<<span class="type">void</span>(<span class="type">const</span> std::vector<Parameter> &,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter> &)> FuncCode;</span><br><span class="line"></span><br><span class="line"><span class="keyword">struct</span> <span class="title class_">FuncArgumentIndices</span> {</span><br><span class="line"> std::vector<<span class="type">int</span>> inputs;</span><br><span class="line"> std::vector<<span class="type">int</span>> outputs;</span><br><span class="line">};</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">ToyExecutable</span> : <span class="keyword">public</span> Executable 
{</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="built_in">ToyExecutable</span>(<span class="type">const</span> std::string &name, <span class="type">const</span> <span class="type">int</span> num_inputs,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter> &outputs,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter> &temp_buffers,</span><br><span class="line"> <span class="type">const</span> std::vector<FuncCode> &func_codes,</span><br><span class="line"> <span class="type">const</span> std::vector<FuncArgumentIndices> &func_args);</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="type">bool</span> <span class="title">Run</span><span class="params">(<span class="type">const</span> std::vector<Parameter> &inputs,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> ExecutableRunOptions &run_options,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">bool</span> block_until_done = <span class="literal">true</span>)</span> <span class="keyword">override</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">private</span>:</span><br><span class="line"> <span class="type">int</span> num_inputs_;</span><br><span class="line"> std::vector<Parameter> outputs_;</span><br><span class="line"> std::vector<Parameter> temp_buffers_;</span><br><span class="line"> std::vector<FuncCode> func_codes_;</span><br><span class="line"> std::vector<FuncArgumentIndices> func_args_;</span><br><span class="line">};</span><br><span class="line"></span><br><span class="line">} <span class="comment">// namespace xrt</span></span><br><span class="line">} <span class="comment">// namespace oneflow</span></span><br><span class="line"></span><br><span 
class="line"><span class="meta">#<span class="keyword">endif</span> <span class="comment">// ONEFLOW_XRT_COMPILER_TOY_TOY_EXECUTABLE_H_</span></span></span><br></pre></td></tr></table></figure><p>在toy_executable.cpp中实现Run方法,这里我们只是简单的顺序执行编排好的函数func_codes。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span 
class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"oneflow_xrt/compiler/toy/toy_executable.h"</span></span></span><br><span class="line"></span><br><span class="line"><span class="keyword">namespace</span> oneflow {</span><br><span class="line"><span class="keyword">namespace</span> xrt {</span><br><span class="line"></span><br><span class="line">ToyExecutable::<span class="built_in">ToyExecutable</span>(<span class="type">const</span> std::string &name, <span class="type">const</span> <span class="type">int</span> num_inputs,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter> &outputs,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter> &temp_buffers,</span><br><span class="line"> <span class="type">const</span> std::vector<FuncCode> &func_codes,</span><br><span class="line"> <span class="type">const</span> std::vector<FuncArgumentIndices> &func_args)</span><br><span class="line"> : <span class="built_in">Executable</span>(name, XrtEngine::TOY),</span><br><span class="line"> <span class="built_in">num_inputs_</span>(num_inputs),</span><br><span class="line"> <span class="built_in">outputs_</span>(outputs),</span><br><span class="line"> <span class="built_in">temp_buffers_</span>(temp_buffers),</span><br><span class="line"> <span class="built_in">func_codes_</span>(func_codes),</span><br><span class="line"> <span class="built_in">func_args_</span>(func_args) {}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">bool</span> <span class="title">ToyExecutable::Run</span><span class="params">(<span class="type">const</span> std::vector<Parameter> &inputs,</span></span></span><br><span class="line"><span class="params"><span 
class="function"> <span class="type">const</span> ExecutableRunOptions &run_options,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">bool</span> block_until_done)</span> </span>{</span><br><span class="line"> <span class="keyword">auto</span> PullArgs = [&](<span class="type">const</span> std::vector<<span class="type">int</span>> &indices) {</span><br><span class="line"> std::vector<Parameter> args;</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> idx : indices) {</span><br><span class="line"> <span class="keyword">if</span> (idx < num_inputs_) {</span><br><span class="line"> args.<span class="built_in">push_back</span>(inputs[idx]);</span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (idx < num_inputs_ + outputs_.<span class="built_in">size</span>()) {</span><br><span class="line"> args.<span class="built_in">push_back</span>(outputs_[idx - num_inputs_]);</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> idx -= (num_inputs_ + outputs_.<span class="built_in">size</span>());</span><br><span class="line"> <span class="built_in">CHECK_GE</span>(idx, <span class="number">0</span>);</span><br><span class="line"> <span class="built_in">CHECK_LT</span>(idx, temp_buffers_.<span class="built_in">size</span>());</span><br><span class="line"> args.<span class="built_in">push_back</span>(temp_buffers_[idx]);</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span> std::<span class="built_in">move</span>(args);</span><br><span class="line"> };</span><br><span class="line"></span><br><span class="line"> <span class="built_in">CHECK_EQ</span>(inputs.<span class="built_in">size</span>(), num_inputs_);</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span> (<span 
class="type">int</span> i = <span class="number">0</span>; i < func_codes_.<span class="built_in">size</span>(); ++i) {</span><br><span class="line"> <span class="keyword">auto</span> in_args = <span class="built_in">PullArgs</span>(func_args_[i].inputs);</span><br><span class="line"> <span class="keyword">auto</span> out_args = <span class="built_in">PullArgs</span>(func_args_[i].outputs);</span><br><span class="line"> func_codes_[i](in_args, out_args);</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Synchronize stream if block_until_done</span></span><br><span class="line"> <span class="keyword">if</span> (block_until_done) {</span><br><span class="line"> <span class="comment">// TODO()</span></span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// All return params are the results of the executable</span></span><br><span class="line"> <span class="keyword">this</span>->results_ = run_options.return_params;</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">true</span> <span class="comment">/*running status*/</span>;</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">} <span class="comment">// namespace xrt</span></span><br><span class="line">} <span class="comment">// namespace oneflow</span></span><br></pre></td></tr></table></figure><p>目前为止我们已经完成了一个最简单的运行时executable,这个executable甚至有点类似其他框架中提供的最简单的图执行器(graphexecutor)。接下来我们要介绍如何将一个XRT的子图编译成上面的ToyExecutable。</p><h2 id="toy-compiler">Toy 
Compiler</h2><p>每个后端引擎都对应一个compiler,当我们希望使用某个后端引擎来执行一个XRT子图时,就需要有一个对应的compiler将该子图编译成后端引擎对应的executable。Compiler通常都非常注重编译产物的执行性能,而性能以外的关切点也导致了不同的技术路线,比如对算法通用性、跨平台有高度关切的TVM和XLA采用了LLVM传统编译器的路线,而对于过分看重性能但硬件平台单一的TensorRT更多的则是采用手工优化和tuning相结合的策略。不过这两种技术路线并不是完全对立的,也是在不断地相互借鉴和融合。</p><p>在XRT中,所有这些技术方案都是可以被兼容的,你可以根据实际情况自由切换,你也可以把XRT当成实验场所,实现一个自定义的compiler,并在同一套框架下对比不同compiler、不同技术方案的优劣。</p><p>回到本文的主题,我们现在需要实现一个ToyExecutable对应的compiler,我们也把该compiler叫做ToyGraphCompiler。</p><p>首先在<code>oneflow_xrt/compiler/toy</code>目录下新建两个文件toy_graph_compiler.h和toy_graph_compiler.cpp。在toy_graph_compiler.h文件中定义类ToyGraphCompiler,ToyGraphCompiler必须继承自类GraphCompiler::Impl,并实现对应的Compile接口。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ToyGraphCompiler</span> : <span class="keyword">public</span> GraphCompiler::Impl {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="function"><span class="keyword">explicit</span> <span class="title">ToyGraphCompiler</span><span class="params">(<span class="type">const</span> std::string &name)</span></span></span><br><span class="line"><span class="function"> : GraphCompiler::Impl(name) {</span>}</span><br><span class="line"></span><br><span class="line"> <span class="keyword">virtual</span> ~<span class="built_in">ToyGraphCompiler</span>() = <span class="keyword">default</span>;</span><br><span class="line"></span><br><span class="line"> <span 
class="function">std::shared_ptr<Executable> <span class="title">Compile</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> XrtGraph *graph,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<Parameter> &entry_params,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<Parameter> &return_params,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<InputOutputAlias> &aliases)</span> <span class="keyword">override</span></span>;</span><br><span class="line">};</span><br></pre></td></tr></table></figure><p>在toy_graph_compiler.cpp中实现Compile接口,并注册一个新的graphcompiler。在动手实现该接口之前,有必要先解释一下该接口的参数列表,graph表示的是function子图,entry_params表示子图的输入,return_params表示子图的输出,aliases通常在包含模型更新操作时会用到,表明输出和输入是一对别名关系。被alias的输入将生命期延长到了整个子图,并且与对应的输出共享内存,因此也就间接实现了inplace计算的目的。</p><p>我们按拓扑顺序遍历子图中的每个节点(或op),依次将节点编译成具体的执行代码,并在合适的位置插入临时buffer。为了方便处理不同类型的op,我们在下面的代码中引入了ToyOpContext和ToyOpKernel的概念。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span 
class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Register a new graph compiler for TOY engine.</span></span><br><span class="line"><span class="built_in">REGISTER_GRAPH_COMPILER</span>(XrtEngine::TOY, ToyGraphCompiler);</span><br><span class="line"></span><br><span class="line"><span class="comment">// Realize Compile interface.</span></span><br><span class="line"><span class="function">std::shared_ptr<Executable> <span class="title">ToyGraphCompiler::Compile</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> XrtGraph *graph,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<Parameter> &entry_params,</span></span></span><br><span 
class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<Parameter> &return_params,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<InputOutputAlias> &aliases)</span> </span>{</span><br><span class="line"> std::vector<Parameter> temp_buffers;</span><br><span class="line"> std::vector<FuncCode> func_codes;</span><br><span class="line"> std::vector<FuncArgumentIndices> func_args;</span><br><span class="line"></span><br><span class="line"> std::unordered_map<std::string, <span class="type">int</span>> indices;</span><br><span class="line"> std::unordered_map<std::string, Parameter> all_params;</span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">auto</span> param : entry_params) {</span><br><span class="line"> indices.<span class="built_in">emplace</span>(param.<span class="built_in">name</span>(), indices.<span class="built_in">size</span>());</span><br><span class="line"> all_params[param.<span class="built_in">name</span>()] = param;</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">auto</span> param : return_params) {</span><br><span class="line"> indices.<span class="built_in">emplace</span>(param.<span class="built_in">name</span>(), indices.<span class="built_in">size</span>());</span><br><span class="line"> all_params[param.<span class="built_in">name</span>()] = param;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> algorithm::<span class="built_in">TopologyVisit</span>(*graph, [&](<span class="type">const</span> XrtNode *node) {</span><br><span class="line"> <span class="keyword">if</span> (node-><span class="built_in">IsNoOpNode</span>()) {</span><br><span class="line"> <span class="comment">// NoOp node is not computation node, so skip it</span></span><br><span class="line"> <span 
class="keyword">return</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> ToyOpContext <span class="built_in">op_context</span>(node, all_params);</span><br><span class="line"> <span class="keyword">auto</span> op_kernel = <span class="built_in">BuildToyOpKernel</span>(node-><span class="built_in">type</span>());</span><br><span class="line"> op_kernel-><span class="built_in">Compile</span>(&op_context);</span><br><span class="line"></span><br><span class="line"> func_codes.<span class="built_in">push_back</span>(op_context.func_code_);</span><br><span class="line"></span><br><span class="line"> <span class="type">const</span> <span class="keyword">auto</span> &buffers = op_context.tmp_buffers_;</span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">auto</span> it = buffers.<span class="built_in">begin</span>(); it != buffers.<span class="built_in">end</span>(); ++it) {</span><br><span class="line"> all_params[it->first] = it->second;</span><br><span class="line"> temp_buffers.<span class="built_in">push_back</span>(it->second);</span><br><span class="line"> indices.<span class="built_in">emplace</span>(it->first, indices.<span class="built_in">size</span>());</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Finalize argument indices for each function</span></span><br><span class="line"> FuncArgumentIndices arg_indices;</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">const</span> <span class="keyword">auto</span> &arg : op_context.input_args_) {</span><br><span class="line"> arg_indices.inputs.<span class="built_in">push_back</span>(indices.<span class="built_in">at</span>(arg));</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">const</span> <span class="keyword">auto</span> &arg : op_context.output_args_) 
{</span><br><span class="line"> arg_indices.outputs.<span class="built_in">push_back</span>(indices.<span class="built_in">at</span>(arg));</span><br><span class="line"> }</span><br><span class="line"> func_args.<span class="built_in">push_back</span>(std::<span class="built_in">move</span>(arg_indices));</span><br><span class="line"> });</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> std::<span class="built_in">make_shared</span><ToyExecutable>(<span class="keyword">this</span>->name_, entry_params.<span class="built_in">size</span>(),</span><br><span class="line"> return_params, temp_buffers,</span><br><span class="line"> func_codes, func_args);</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>ToyOpContext临时存储编译需要的元信息和编译结果,为ToyOpKernel提供必要的接口,ToyOpKernel则根据op类型完成单个op的编译过程。上述代码中我们实现了一个将XRT子图编译成ToyExecutable的最简单的graphcompiler,下面我们将以ReLUop为例,介绍ToyOpContext和ToyOpKernel是如何对op进行编译的。</p><h2 id="toy-kernels">Toy Kernels</h2><p>我们回过头再仔细研究一下ToyGraphCompiler的Compile实现,ToyOpContext接受两个输入,node和当前所有已经创建过的parameters,经过OpKernel编译后输出函数代码(func_code_)、中间buffer(tmp_buffers_),以及函数代码输入和输出对应的parameternames。因此在这个例子中,ToyOpContext被设计成如下形式:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ToyOpContext</span> {</span><br><span class="line"> <span 
class="keyword">public</span>:</span><br><span class="line"> <span class="built_in">ToyOpContext</span>(<span class="type">const</span> XrtNode *node,</span><br><span class="line"> <span class="type">const</span> std::unordered_map<std::string, Parameter> &all_params)</span><br><span class="line"> : <span class="built_in">node_</span>(node), <span class="built_in">all_params_</span>(all_params) {}</span><br><span class="line"></span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="type">const</span> XrtNode *node_;</span><br><span class="line"> <span class="type">const</span> std::unordered_map<std::string, Parameter> &all_params_;</span><br><span class="line"></span><br><span class="line"> std::function<<span class="type">void</span>(<span class="type">const</span> std::vector<Parameter>&,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter>&)> func_code_;</span><br><span class="line"> std::vector<std::string> input_args_;</span><br><span class="line"> std::vector<std::string> output_args_;</span><br><span class="line"> std::unordered_map<std::string, Parameter> tmp_buffers_;</span><br><span class="line">};</span><br></pre></td></tr></table></figure><p>对于ToyOpKernel,为了处理不同类型的op,我们采用工厂注册模式,并且这种模式还有另一个用处,就是在XRT划分子图时可以用来判断该引擎是否支持某个类型的op。XRT已经将kernel注册接口封装成了一个辅助类OpKernelRegistrar,但同时也要求ToyOpKernel必须继承基类OpKernel。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ToyOpKernel</span> : <span class="keyword">public</span> OpKernel<ToyOpContext> {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="function"><span class="keyword">virtual</span> <span class="type">void</span> <span 
class="title">Compile</span><span class="params">(ToyOpContext *ctx)</span> </span>= <span class="number">0</span>;</span><br><span class="line">};</span><br></pre></td></tr></table></figure><p>使用OpKernelRegistrar定义一个用来注册ToyOpKernel的宏。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">define</span> REGISTER_TOY_OP_KERNEL(OpName, KernelType) \</span></span><br><span class="line"><span class="meta"> static auto _toy_op_kernel_##OpName##_ __attribute__((unused)) = \</span></span><br><span class="line"><span class="meta"> OpKernelRegistrar(#OpName) \</span></span><br><span class="line"><span class="meta"> .SetEngine(XrtEngine::TOY) \</span></span><br><span class="line"><span class="meta"> .SetDevice({XrtDevice::GPU_CUDA}) \</span></span><br><span class="line"><span class="meta"> .SetFactory([]() -> OpKernelBase * { \</span></span><br><span class="line"><span class="meta"> return new KernelType; \</span></span><br><span class="line"><span class="meta"> })</span></span><br></pre></td></tr></table></figure><p>最后我们实现一个Relu的OpKernel,填充ToyOpContext的func_code_、tmp_buffers_以及输入输出arguments。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span 
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">ComputeRelu</span><span class="params">(<span class="type">const</span> Parameter &input, <span class="type">const</span> Parameter &output)</span> </span>{</span><br><span class="line"> <span class="comment">//TODO(hjchen2)</span></span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">ToyReluOpKernel</span> : <span class="keyword">public</span> ToyOpKernel {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">Compile</span><span class="params">(ToyOpContext *ctx)</span> <span class="keyword">override</span> </span>{</span><br><span class="line"> ctx->func_code_ = [](<span class="type">const</span> std::vector<Parameter> &inputs,</span><br><span class="line"> <span class="type">const</span> std::vector<Parameter> &outputs) {</span><br><span class="line"> <span class="built_in">CHECK_EQ</span>(inputs.<span class="built_in">size</span>(), <span class="number">1</span>);</span><br><span class="line"> <span class="built_in">CHECK_EQ</span>(outputs.<span class="built_in">size</span>(), <span 
class="number">1</span>);</span><br><span class="line"> <span class="built_in">ComputeRelu</span>(inputs[<span class="number">0</span>], outputs[<span class="number">0</span>]);</span><br><span class="line"> };</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span> (<span class="type">const</span> XrtEdge *edge : ctx->node_-><span class="built_in">in_edges</span>()) {</span><br><span class="line"> <span class="type">const</span> <span class="keyword">auto</span> &name = edge-><span class="built_in">argument</span>().<span class="built_in">name</span>();</span><br><span class="line"> <span class="built_in">CHECK_GT</span>(ctx->all_params_.<span class="built_in">count</span>(name), <span class="number">0</span>);</span><br><span class="line"> <span class="comment">// TODO(): Filter duplicate input</span></span><br><span class="line"> ctx->input_args_.<span class="built_in">push_back</span>(name);</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span> (<span class="type">const</span> XrtEdge *edge : ctx->node_-><span class="built_in">out_edges</span>()) {</span><br><span class="line"> <span class="type">const</span> <span class="keyword">auto</span> &name = edge-><span class="built_in">argument</span>().<span class="built_in">name</span>();</span><br><span class="line"> <span class="comment">// TODO(): Filter duplicate output</span></span><br><span class="line"> ctx->output_args_.<span class="built_in">push_back</span>(name);</span><br><span class="line"> <span class="keyword">if</span> (ctx->all_params_.<span class="built_in">count</span>(name) == <span class="number">0</span> &&</span><br><span class="line"> ctx->tmp_buffers_.<span class="built_in">count</span>(name) == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">auto</span> param = <span class="built_in">CreateParameter</span>(name <span class="comment">/*argument 
name*/</span>,</span><br><span class="line"> edge-><span class="built_in">argument</span>().<span class="built_in">shape</span>(),</span><br><span class="line"> edge-><span class="built_in">argument</span>().<span class="built_in">data_type</span>());</span><br><span class="line"> ctx->tmp_buffers_[name] = std::<span class="built_in">move</span>(param);</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">};</span><br></pre></td></tr></table></figure><p>最后将ToyReluOpKernel注册到Toy引擎对应的OpKernel工厂下。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">REGISTER_TOY_OP_KERNEL</span>(relu, ToyReluOpKernel)</span><br><span class="line"> .<span class="built_in">EnableTrainPhase</span>()</span><br><span class="line"> .<span class="built_in">Finalize</span>();</span><br></pre></td></tr></table></figure><p>EnableTrainPhase表示该op支持训练,OpKernelRegistrar也提供了其他一些接口,比如设置支持的device列表,mutablevariables(inplace更新)和是否是model update op(model updateop会影响子图划分)。</p><h2 id="cmake编译">CMake编译</h2><p>在CMakeList.txt中添加一个BUILD_TOY的选项,并在oneflow_xrt/CMakeLists.txt中添加如下toy引擎模块的编译代码,</p><figure class="highlight cmake"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span>(BUILD_TOY)</span><br><span class="line"> <span class="keyword">file</span>(GLOB_RECURSE XRT_TOY_SRCS 
compiler/toy/*.cpp)</span><br><span class="line"> <span class="keyword">add_library</span>(oneflow_xrt_toy <span class="variable">${XRT_TOY_SRCS}</span>)</span><br><span class="line"> <span class="keyword">add_dependencies</span>(</span><br><span class="line"> oneflow_xrt_toy</span><br><span class="line"> <span class="variable">${XRT_THIRD_PARTY_LIBRARIES}</span>)</span><br><span class="line"> <span class="keyword">target_link_libraries</span>(</span><br><span class="line"> oneflow_xrt_toy</span><br><span class="line"> oneflow_xrt</span><br><span class="line"> <span class="variable">${XRT_THIRD_PARTY_LIBRARIES}</span>)</span><br><span class="line"> <span class="keyword">target_include_directories</span>(</span><br><span class="line"> oneflow_xrt_toy PRIVATE <span class="variable">${ONEFLOW_INCLUDE_DIR}</span>)</span><br><span class="line"><span class="keyword">endif</span>()</span><br></pre></td></tr></table></figure><p>之后在oneflow_xrt/python目录中添加导出Python模块的代码toy_stub.cpp,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><pybind11/pybind11.h></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><pybind11/stl.h></span></span></span><br><span class="line"></span><br><span class="line"><span class="built_in">PYBIND11_MODULE</span>(_oneflow_xrt_toy_internal, m) {}</span><br></pre></td></tr></table></figure><p>并在oneflow_xrt/python/CMakeLists.txt中增加如下代码,</p><figure class="highlight cmake"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span>(BUILD_TOY)</span><br><span class="line"> 
oneflow_xrt_add_stub(oneflow_xrt_toy toy_stub.cpp)</span><br><span class="line"><span class="keyword">endif</span>()</span><br></pre></td></tr></table></figure><h2 id="编译和安装python-wheel包">编译和安装Python wheel包</h2><p>修改setup.py文件,新增一个toyextension的编译,并在build_ext函数中开启BUILD_TOY选项,</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">setup_extension(</span><br><span class="line"> <span class="string">"oneflow_xrt_toy"</span>,</span><br><span class="line"> cmake_args=[<span class="string">"-DBUILD_TOY=ON"</span>],</span><br><span class="line"> description=(<span class="string">"oneflow_xrt's toy extension"</span>),</span><br><span class="line">)</span><br></pre></td></tr></table></figure><p>执行命令<code>python3 setup.py install</code>完成wheel包的编译和安装,最后执行如下代码测试添加的toy引擎是否可以正常执行,</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> oneflow <span class="keyword">as</span> flow</span><br><span class="line"><span class="keyword">import</span> oneflow_xrt <span class="keyword">as</span> flowrt</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">ReluGraph</span>(flow.nn.Graph):</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span 
class="params">self</span>):</span><br><span class="line"> <span class="built_in">super</span>().__init__()</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">build</span>(<span class="params">self, <span class="built_in">input</span></span>):</span><br><span class="line"> <span class="keyword">return</span> flow.nn.functional.relu(<span class="built_in">input</span>)</span><br><span class="line"></span><br><span class="line">m = flowrt.XRTModule(ReluGraph(), engine=<span class="string">"toy"</span>)</span><br><span class="line">x = flow.randn(<span class="number">2</span>, <span class="number">3</span>, device=<span class="string">"cuda"</span>)</span><br><span class="line"><span class="built_in">print</span>(m(x))</span><br></pre></td></tr></table></figure>]]></content>
    <summary type="html"><p>XRT provides a unified set of upper-layer features and interface abstractions for the different backend engines, including:</p>
<ul>
<li>A unified DAG computation-graph representation</li>
<li>A unified process for subgraph representation, partitioning, and folding</li>
<li>A unified JIT subgraph compilation interface and caching mechanism</li>
<li>A unified Executable Launch interface</li>
</ul>
<p>Thanks to these unified upper-layer abstractions and the modular design, a backend engine only needs to implement a small number of engine-specific interfaces, and these differences are usually confined to the concrete implementations of subgraph compilation and the executable launch interface.</p></summary>
<category term="XRT" scheme="https://hjchen2.github.io/categories/XRT/"/>
<category term="XRT" scheme="https://hjchen2.github.io/tags/XRT/"/>
<category term="Compiler" scheme="https://hjchen2.github.io/tags/Compiler/"/>
<category term="TensorFlow XLA" scheme="https://hjchen2.github.io/tags/TensorFlow-XLA/"/>
<category term="TensorRT" scheme="https://hjchen2.github.io/tags/TensorRT/"/>
</entry>
<entry>
    <title>TVM PackedFunc Implementation Mechanism</title>
<link href="https://hjchen2.github.io/2020/01/10/TVM-PackedFunc%E5%AE%9E%E7%8E%B0%E6%9C%BA%E5%88%B6/"/>
<id>https://hjchen2.github.io/2020/01/10/TVM-PackedFunc%E5%AE%9E%E7%8E%B0%E6%9C%BA%E5%88%B6/</id>
<published>2020-01-10T04:24:08.000Z</published>
<updated>2023-02-07T02:41:00.291Z</updated>
<content type="html"><![CDATA[<h2 id="tvm-packedfunc实现">TVM PackedFunc实现</h2><p>为了便于Python和C++混合编程,TVM使用了统一的PackedFunc机制。PackedFunc可以将C++中的各类函数打包成统一的函数接口,并自动导出到Python模块中进行调用,并且也支持从Python中注册一个函数,并伪装成PackedFunc在C++和Python中调用。</p><span id="more"></span><p><img src="https://github.com/hjchen2/personal/blob/master/blog/tvm/屏幕快照%202020-01-10%2010.55.45.png?raw=true" style="zoom:36%;" /></p><h3 id="预备知识">预备知识</h3><h4 id="python-ctypes混合编程">Python ctypes混合编程</h4><p>ctypes是Python自带的跨语言函数调用库,ctypes提供了简单的C数据类型,可以将C/C++动态库中的函数包装成Python函数进行调用。</p><ul><li><p>导出C++函数</p><p>首先在C++中定义一个全局函数,并编译生成C++动态库。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// test.h</span></span><br><span class="line"><span class="keyword">extern</span> <span class="string">"C"</span> {</span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">add</span><span class="params">(<span class="type">int</span> a, <span class="type">int</span> b)</span></span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// test.cc</span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"test.h"</span></span></span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">add</span><span class="params">(<span class="type">int</span> a, <span class="type">int</span> b)</span> </span>{</span><br><span class="line"> <span class="keyword">return</span> a + 
b;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>用ctypes模块在Python中加载生成的动态库(test.so),并调用C++中的函数。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> ctypes</span><br><span class="line"></span><br><span class="line"><span class="comment"># Load shared library</span></span><br><span class="line">_LIB = ctypes.CDLL(<span class="string">"./test.so"</span>, ctypes.RTLD_GLOBAL)</span><br><span class="line"></span><br><span class="line">a = ctypes.c_int(<span class="number">1</span>)</span><br><span class="line">b = ctypes.c_int(<span class="number">2</span>)</span><br><span class="line"><span class="comment"># Call C func in Python</span></span><br><span class="line"><span class="built_in">print</span>(_LIB.add(a, b))</span><br><span class="line"><span class="comment"># Or</span></span><br><span class="line"><span class="built_in">print</span>(_LIB.add(<span class="number">1</span>, <span class="number">2</span>))</span><br></pre></td></tr></table></figure></li><li><p>传递Python函数到C++</p><p>ctypes也支持将Python函数转换成C类型的函数,并在C/C++中进行调用。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">add</span>(<span class="params">a, b</span>):</span><br><span class="line"> <span class="keyword">return</span> a + b</span><br></pre></td></tr></table></figure><p>Pythonadd有两个参数a和b,返回值类型与a和b的类型一致。在C++中可以为Pythonadd定义一个函数原型 int(int, 
int)。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">extern</span> <span class="string">"C"</span> {</span><br><span class="line"><span class="function"><span class="keyword">typedef</span> <span class="title">int</span> <span class="params">(*PyCFunc)</span><span class="params">(<span class="type">int</span>, <span class="type">int</span>)</span></span>;</span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">call_py_func</span><span class="params">(PyCFunc f, <span class="type">int</span> a, <span class="type">int</span> b)</span></span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"test.h"</span></span></span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">call_py_func</span><span class="params">(PyCFunc f, <span class="type">int</span> a, <span class="type">int</span> b)</span> </span>{</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">f</span>(a, b);</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>使用ctypes将Python函数转换成C function,传入C++中进行调用。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> ctypes</span><br><span class="line"></span><br><span class="line">cfunc = ctypes.CFUNCTYPE(</span><br><span class="line"> ctypes.c_int, <span class="comment"># return type</span></span><br><span class="line"> ctypes.c_int, <span class="comment"># arg0 type</span></span><br><span class="line"> ctypes.c_int <span class="comment"># arg1 type</span></span><br><span class="line"> )</span><br><span class="line"></span><br><span class="line">f = cfunc(add)</span><br><span class="line"><span class="comment"># CFUNCTYPE is callable in Python</span></span><br><span class="line"><span class="built_in">print</span>(f(<span class="number">5</span>, <span class="number">1</span>))</span><br><span class="line"></span><br><span class="line"><span class="comment"># Call Python func in C</span></span><br><span class="line"><span class="built_in">print</span>(_LIB.call_py_func(f, <span class="number">5</span>, <span class="number">1</span>))</span><br></pre></td></tr></table></figure></li></ul><h3 id="packedfunc实现">PackedFunc实现</h3><h4 id="packedfunc定义">PackedFunc定义</h4><p>ctypes可以很方便的将C/C++中的函数导出到Python,调用时直接传入对应的参数即可,但如果需要将Python函数导入到C/C++,则需要在C/C++中提前定义好对应的函数原型(比如上面的PyCFunc),并提供对应函数的调用入口(call_py_func)。为了支持更加灵活的函数定义,TVM将不同类型的函数包装成统一的函数原型。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">void</span>(TVMArgs args, TVMRetValue *rv);</span><br></pre></td></tr></table></figure><p>统一的函数原型被封装成PackedFunc对象,提供通用的调用接口,直接与调用者进行交互。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">PackedFunc</span> {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="keyword">using</span> FType = std::function<<span class="built_in">void</span> (TVMArgs args, TVMRetValue* rv)>;</span><br><span class="line"> <span class="function"><span class="keyword">template</span><<span class="keyword">typename</span>... Args></span></span><br><span class="line"><span class="function"> <span class="keyword">inline</span> TVMRetValue <span class="title">operator</span><span class="params">()</span><span class="params">(Args&& ...args)</span> <span class="type">const</span></span>;</span><br><span class="line"> <span class="function"><span class="keyword">inline</span> <span class="type">void</span> <span class="title">CallPacked</span><span class="params">(TVMArgs args, TVMRetValue* rv)</span> <span class="type">const</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">private</span>:</span><br><span class="line"> <span class="comment">/*! 
\brief internal container of packed function */</span></span><br><span class="line"> FType body_;</span><br><span class="line">};</span><br></pre></td></tr></table></figure><p>当获得一个PackedFunc对象时,我们就可以像调用普通函数一样调用PackedFunc打包的函数。比如:</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">PackedFunc f;</span><br><span class="line"><span class="comment">// f(1, 2)首先会自动将参数1,2打包成TVMArgs,接着调用CallPacked,CallPacked最终的执行体是body_</span></span><br><span class="line">TVMRetValue ret = <span class="built_in">f</span>(<span class="number">1</span>, <span class="number">2</span>);</span><br></pre></td></tr></table></figure><h4 id="函数打包">函数打包</h4><p>TVM支持对各类函数进行打包,包括一般的函数、类的成员函数以及lambda表达式。</p><ul><li><p>函数原型萃取</p><p>萃取函数原型是为了得到函数的参数和返回值类型。TVM中使用decltype和模板结构体function_signature来实现。</p><p>比如定义一个简单的C函数,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">int</span> <span class="title">add</span><span class="params">(<span class="type">int</span> a, <span class="type">int</span> b)</span> </span>{</span><br><span class="line"> <span class="keyword">return</span> a + b;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>接下来就可以使用如下的代码来萃取add的函数原型,</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> R, <span class="keyword">typename</span> 
...Args></span><br><span class="line"><span class="keyword">struct</span> <span class="title class_">function_signature</span><<span class="built_in">R</span>(Args...)> {</span><br><span class="line"> <span class="keyword">using</span> FType = <span class="built_in">R</span>(Args...);</span><br><span class="line">};</span><br><span class="line"></span><br><span class="line"><span class="comment">// 萃取add的函数原型</span></span><br><span class="line"><span class="keyword">using</span> FType = function_signature<<span class="keyword">decltype</span>(add)>::FType;</span><br></pre></td></tr></table></figure><p>此外只需要特化function_signature就可以支持函数指针和lambda表达式。注意:TVM的function_signature不支持普通成员函数的类型萃取,因此TVM需要借助一个辅助function_signature_helper来对lambda表达式类型进行萃取,而我们这里的function_signature支持普通成员函数,因此lambda表达式类型萃取可以通过递归的function_signature来实现。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// 普通函数指针</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> R, <span class="keyword">typename</span> ...Args></span><br><span class="line"><span class="keyword">struct</span> <span class="title class_">function_signature</span><<span 
class="built_in">R</span>(*)(Args...)> {</span><br><span class="line"> <span class="keyword">using</span> FType = <span class="built_in">R</span>(Args...);</span><br><span class="line">};</span><br><span class="line"></span><br><span class="line"><span class="comment">// 非const类的成员函数指针</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> R, <span class="keyword">typename</span> ...Args></span><br><span class="line"> <span class="keyword">struct</span> <span class="title class_">function_signature</span><<span class="built_in">R</span>(T::*)(Args...)> {</span><br><span class="line"> <span class="keyword">using</span> FType = <span class="built_in">R</span>(Args...);</span><br><span class="line">};</span><br><span class="line"></span><br><span class="line"><span class="comment">// const类的成员函数指针</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> R, <span class="keyword">typename</span> ...Args></span><br><span class="line"> <span class="keyword">struct</span> <span class="title class_">function_signature</span><<span class="built_in">R</span>(T::*)(Args...) 
<span class="type">const</span>> {</span><br><span class="line"> <span class="keyword">using</span> FType = <span class="built_in">R</span>(Args...);</span><br><span class="line">};</span><br><span class="line"></span><br><span class="line"><span class="comment">// lambda表达式</span></span><br><span class="line"><span class="keyword">template</span><<span class="keyword">typename</span> T></span><br><span class="line"><span class="keyword">struct</span> <span class="title class_">function_signature</span> {</span><br><span class="line"> <span class="keyword">using</span> FType = <span class="keyword">typename</span> function_signature<<span class="keyword">decltype</span>(&T::<span class="built_in">operator</span>())>::FType;</span><br><span class="line">};</span><br></pre></td></tr></table></figure></li><li><p>函数打包</p><p>一旦萃取到了函数原型,TVM就利用TypedPackedFunc对普通函数或lambda表达式进行打包。TypedPackedFunc只支持对R(Args...)类型的函数打包,所以如果被打包的函数是一个函数指针,则需要创建一个lambda表达式,转换成R(Args...)类型之后再用TypedPackedFunc对创建的lambda表达式进行打包。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">template</span><<span class="keyword">typename</span> R, <span class="keyword">typename</span> ...Args></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">TypedPackedFunc</span><<span class="built_in">R</span>(Args...)> {</span><br><span class="line"> <span 
class="keyword">public</span>:</span><br><span class="line"> <span class="keyword">using</span> TSelf = TypedPackedFunc<<span class="built_in">R</span>(Args...)>;</span><br><span class="line"> <span class="keyword">template</span><<span class="keyword">typename</span> FLambda,</span><br><span class="line"> <span class="keyword">typename</span> = <span class="keyword">typename</span> std::enable_if<</span><br><span class="line"> std::is_convertible<FLambda,</span><br><span class="line"> std::function<<span class="built_in">R</span>(Args...)></span><br><span class="line"> >::value>::type></span><br><span class="line"> <span class="built_in">TypedPackedFunc</span>(<span class="type">const</span> FLambda& typed_lambda) { <span class="comment">// NOLINT(*)</span></span><br><span class="line"> <span class="keyword">this</span>-><span class="built_in">AssignTypedLambda</span>(typed_lambda);</span><br><span class="line"> }</span><br><span class="line"> ...</span><br><span class="line"> <span class="keyword">private</span>:</span><br><span class="line"> ...</span><br><span class="line"> PackedFunc packed_;</span><br><span class="line">};</span><br></pre></td></tr></table></figure><p>当被打包的函数用来实例化TypedPackedFunc对象时,会立刻调用AssignTypedLambda将被打包的函数打包成PackedFunc。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">template</span><<span class="keyword">typename</span> R, <span class="keyword">typename</span> ...Args></span><br><span class="line"><span class="keyword">template</span><<span class="keyword">typename</span> FType></span><br><span class="line"><span class="keyword">inline</span> <span class="type">void</span> TypedPackedFunc<<span 
class="built_in">R</span>(Args...)>::<span class="built_in">AssignTypedLambda</span>(FType flambda) {</span><br><span class="line"> packed_ = <span class="built_in">PackedFunc</span>([flambda](<span class="type">const</span> TVMArgs& args, TVMRetValue* rv) {</span><br><span class="line"> detail::<span class="built_in">unpack_call</span><R, <span class="keyword">sizeof</span>...(Args)>(flambda, args, rv);</span><br><span class="line"> });</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>AssignTypedLambda实际上是将被打包的函数先封装成了一个函数原型为void(const TVMArgs& args, TVMRetValue* rv)的lambda表达式,然后将这个lambda表达式作为PackedFunc对象的一个成员,通过设置合适的接口(重载operator()),使得PackedFunc与被打包的源函数表现得完全一样。</p></li></ul><h3 id="自动导出函数">自动导出函数</h3><p>TVM将需要从C++自动导出的函数打包成PackedFunc,然后通过宏TVM_REGISTER_GLOBAL注册到全局的一个map中。比如:<figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">TVM_REGISTER_GLOBAL</span>(<span class="string">"_Var"</span>)</span><br><span class="line">.<span class="built_in">set_body_typed</span>([](std::string s, DataType t) {</span><br><span class="line"> <span class="keyword">return</span> VarNode::<span class="built_in">make</span>(t, s);</span><br><span class="line"> });</span><br></pre></td></tr></table></figure></p><p>当Python加载编译好的动态库时,会自动查询map中静态注册的函数,每个函数都包装成Python中的Function对象,最终添加到Python模块中。Function重定义了函数调用接口,自动完成参数打包过程。如果是在Python中动态注册的函数,则需要在Python中通过函数名来查询PackedFunc,返回一个PackedFunc的handle(函数指针),并封装成Function。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span 
class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">get_global_func</span>(<span class="params">name, allow_missing=<span class="literal">False</span></span>):</span><br><span class="line">    handle = FunctionHandle()</span><br><span class="line">    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))</span><br><span class="line">    <span class="keyword">if</span> handle.value:</span><br><span class="line">        <span class="keyword">return</span> Function(handle, <span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> allow_missing:</span><br><span class="line">        <span class="keyword">return</span> <span class="literal">None</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">raise</span> ValueError(<span class="string">"Cannot find global function %s"</span> % name)</span><br></pre></td></tr></table></figure><p>注:TVMFuncGetGlobal是通过ctypes导出的C++接口,FunctionHandle是ctypes中表示void指针的类型(c_void_p)。</p><h3 id="从python注册函数">从Python注册函数</h3><p>由于TVM中PackedFunc的精心设计,我们只需要将Python中的函数转换成统一的函数原型void(const TVMArgs&, TVMRetValue*),然后将函数转换成PackedFunc并动态地注册到全局的map中。</p><p>先将Python函数用ctypes转成int(TVMValue*, int*, int, void*, void*)的C函数。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">TVMPackedCFunc = ctypes.CFUNCTYPE(</span><br><span class="line">    ctypes.c_int,</span><br><span class="line">    ctypes.POINTER(TVMValue),</span><br><span class="line">    ctypes.POINTER(ctypes.c_int),</span><br><span class="line">    ctypes.c_int,</span><br><span class="line">    
ctypes.c_void_p,</span><br><span class="line"> ctypes.c_void_p)</span><br></pre></td></tr></table></figure><p>然后通过TVMFuncCreateFromCFunc将上面的C函数转换成统一的PackedFunc函数。</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">int</span> <span class="title">TVMFuncCreateFromCFunc</span><span class="params">(TVMPackedCFunc func,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">void</span>* resource_handle,</span></span></span><br><span class="line"><span class="params"><span class="function"> TVMPackedCFuncFinalizer fin,</span></span></span><br><span class="line"><span class="params"><span class="function"> TVMFunctionHandle *out)</span> </span>{</span><br><span class="line"> <span class="built_in">API_BEGIN</span>();</span><br><span class="line"> <span class="keyword">if</span> (fin == <span class="literal">nullptr</span>) {</span><br><span class="line"> *out = <span class="keyword">new</span> <span class="built_in">PackedFunc</span>(</span><br><span class="line"> [func, resource_handle](TVMArgs args, TVMRetValue* rv) {</span><br><span class="line"> <span class="type">int</span> ret = <span class="built_in">func</span>((TVMValue*)args.values, (<span class="type">int</span>*)args.type_codes, <span class="comment">// 
NOLINT(*)</span></span><br><span class="line"> args.num_args, rv, resource_handle);</span><br><span class="line"> <span class="keyword">if</span> (ret != <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">throw</span> dmlc::<span class="built_in">Error</span>(<span class="built_in">TVMGetLastError</span>() + ::dmlc::<span class="built_in">StackTrace</span>());</span><br><span class="line"> }</span><br><span class="line"> });</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> ...</span><br><span class="line"> }</span><br><span class="line"> <span class="built_in">API_END</span>();</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>最后通过接口TVMFuncRegisterGlobal注册到全局的map中。下面是从Python中注册一个函数,并在Python中调用的例子。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">targs = (<span class="number">10</span>, <span class="number">10.0</span>, <span class="string">"hello"</span>)</span><br><span class="line"></span><br><span class="line"><span class="meta">@tvm.register_func</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">my_packed_func</span>(<span class="params">*args</span>):</span><br><span class="line"> <span class="keyword">assert</span>(<span class="built_in">tuple</span>(args) == targs)</span><br><span class="line"> <span class="keyword">return</span> <span class="number">10</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Get it out from 
global function table</span></span><br><span class="line">f = tvm.get_global_func(<span class="string">"my_packed_func"</span>)</span><br><span class="line"><span class="keyword">assert</span> <span class="built_in">isinstance</span>(f, tvm.nd.Function)</span><br><span class="line">y = f(*targs)</span><br><span class="line"><span class="keyword">assert</span> y == <span class="number">10</span></span><br></pre></td></tr></table></figure>]]></content>
<summary type="html"><h2 id="tvm-packedfunc实现">TVM PackedFunc实现</h2>
<p>为了便于Python和C++混合编程,TVM使用了统一的PackedFunc机制。PackedFunc可以将C++中的各类函数打包成统一的函数接口,并自动导出到Python模块中供调用;同时也支持从Python中注册一个函数,将其伪装成PackedFunc,在C++和Python中调用。</p></summary>
<category term="tvm knowledge" scheme="https://hjchen2.github.io/categories/tvm-knowledge/"/>
<category term="TVM" scheme="https://hjchen2.github.io/tags/TVM/"/>
<category term="PackedFunc" scheme="https://hjchen2.github.io/tags/PackedFunc/"/>
</entry>
<entry>
<title>图替换</title>
<link href="https://hjchen2.github.io/2019/12/26/%E5%9B%BE%E6%9B%BF%E6%8D%A2/"/>
<id>https://hjchen2.github.io/2019/12/26/%E5%9B%BE%E6%9B%BF%E6%8D%A2/</id>
<published>2019-12-26T05:54:04.000Z</published>
<updated>2023-02-07T02:37:14.493Z</updated>
<content type="html"><![CDATA[<h3 id="背景">背景</h3><p>图替换(或者叫图改写)是一种重要的图优化技术,几乎在所有的开源框架(尤其是移动端框架)中都有应用。比如tensorflowr1.14版本中就包含了155个替换子,而且实现这些替换子的总代码量接近53k行。</p><blockquote><p>一些常见的图优化技术:</p><ul><li><p>DCE</p></li><li><p>CSE(公共子表达式消除)</p></li><li><p>常量折叠</p></li><li><p>数学公式简化</p></li><li><p>Op融合</p></li><li><p>Layout变换</p></li><li><p>内存优化(swap-in/swap-out、重计算)</p></li></ul></blockquote><span id="more"></span><p>由于目前的编译器技术通常基于low-level的中间表达,注重对局部计算的优化,对于跨多个粗粒度op的优化要不无能为力,要不就得增加编译器的分析难度并导致代码膨胀。一般来说AI框架支持的粗粒度op非常有限,而且这些op的组合常常也比较固定,比如convolution通常和bias_add、relu组合使用,因此基于高层中间表达的图替换成为一种比较可行的优化方案。经过图替换优化后的计算图再经过编译器的优化后,生成最终的硬件代码。</p><p>目前主流开源框架的图替换都是基于经验和手工设置的替换子来实现的,在这里统称为经典图替换技术。</p><h3 id="经典图替换">经典图替换</h3><p>图替换是将原始计算图替换成另一个优化后的等价计算图,替换后的计算图通常是硬件友好的,比如可以消除中间结果,降低内存占用,减少访存和计算量,并且不影响最终的计算结果。</p><p>在进行图替换之前,首先需要定义出源计算图到目标计算图的替换规则(替换子),由于这些替换规则往往需要依靠人的经验一条条手工去定义,因此称之为经典图替换。给出一条替换子,我们需要在原始计算图中不断地去匹配替换子的源计算子图,一旦匹配到满足要求的子图后,就将源计算子图重新映射为替换子中的目标计算图。</p><p>在一些开源框架中,替换子的定义形式不尽相同。在TensorFlow中源图匹配和替换的定义是非常松散的,它甚至没有直接定义出替换子的源图,而是定义一系列约束来判断是否匹配。PaddlePaddle中则是将一个替换过程定义为一个pass,pass执行时动态构建相应的替换子源图,执行匹配算法并回调源图到目标图的替换函数。比如下面是TensorFlow中将Conv+BiasAdd替换成FusedConv的过程。</p><ul><li><p>定义匹配约束</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">struct</span> <span class="title class_">ContractionWithBiasAdd</span> {</span><br><span class="line"> <span class="type">int</span> constraction;</span><br><span 
class="line"> <span class="type">int</span> bias_add;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// node为输入的grapper node, pattern为输出的ContractionWithBiasAdd.</span></span><br><span class="line"><span class="function"><span class="type">bool</span> <span class="title">FindContractionWithBias</span><span class="params">(node,*pattern)</span> </span>{</span><br><span class="line"> <span class="comment">// 开始列举匹配的constractions.</span></span><br><span class="line"> <span class="number">1</span>、如果node存在控制边,返回<span class="literal">false</span></span><br><span class="line"> <span class="number">2</span>、如果node不是BiasAdd,返回<span class="literal">false</span></span><br><span class="line"> <span class="number">3</span>、如果node的父节点不是Conv或MatMul,返回<span class="literal">false</span></span><br><span class="line"> <span class="number">4</span>、...</span><br><span class="line"> <span class="comment">// 如果以上所有constructions都满足,则将需要替换的node id写到特定的pattern中。</span></span><br><span class="line"> pattern->constraction = node的父节点;</span><br><span class="line"> pattern->bias_add = node;</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">true</span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-26%2011.03.21.png?raw=true" style="zoom:33%;" /></p><ul><li><p>定义替换过程</p><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// pattern为输入的ContractionWithBiasAdd,</span></span><br><span class="line"><span class="function"><span class="type">void</span> <span 
class="title">AddFusedContractionNode</span><span class="params">(pattern, *invalidated_nodes)</span> </span>{</span><br><span class="line"> <span class="number">1</span>、创建一个新的node:fused_op</span><br><span class="line"> <span class="number">2</span>、将Conv或MatMul的input和filter添加到fused_op的输入中,并将BiasAdd的bias加到fused_op的输入</span><br><span class="line"> <span class="number">3</span>、根据Conv或MatMul的一些参数设置fused_op的参数,比如conv的kernel、channel、padding等等,以及matmul的transpose等</span><br><span class="line"> <span class="number">4</span>、将fused_op加入到graph,同时将Conv或MatMul和BiasAdd加入到invalidated_nodes</span><br><span class="line">}</span><br></pre></td></tr></table></figure></li></ul><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-26%2011.03.35.png?raw=true" style="zoom:33%;" /></p><p>TensorFlow采用的定义匹配约束的方式与直接定义出子图的方式本质上是等价的,但相比后者可读性较差,而优点就是代码可复用性高,比如上面的FindContractionWithBias可以同时匹配Conv+BiasAdd和MatMul+BiasAdd两种子图,并且这些约束便于嵌套使用。</p><p>无论是TensorFlow还是PaddlePaddle,图替换都是不完全的。比如说对于Conv+BiasAdd+BiasAdd这种计算图,第一次只能匹配到Conv+BiasAdd,替换后又变成了一个Conv+BiasAdd的计算图,因此TensorFlow中默认采用了两遍优化。根据TensorFlow公开的一些数据,基本上第二次优化的机会已经非常少了。</p><ul><li><p>InceptionV3</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-26%2011.29.51.png?raw=true" style="zoom:40%;" /></p></li><li><p>Seq2Seq</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-26%2011.30.01.png?raw=true" style="zoom:40%;" /></p></li></ul><h3 id="基于超优化的图替换">基于超优化的图替换</h3><p>超优化(Superoptimization)是现代编译器中的一种指令优化技术,其主要工作原理是通过随机生成指令序列以及暴力搜索的方式自动找到一组优化的指令序列,并等价替换原有的指令序列。1992年第一个Superoptimizer被集成到了GCC编译器,之后Google也为LLVM开发了一个Superoptimizer,取名为Souper。</p><p>依靠人工设定的编译器往往对代码的优化不够彻底,给生成的code留下了大量的优化空隙,而且人工设定的优化规则往往没有经过充分验证,经常导致各种极端条件下的代码bug。Superoptimization将指令序列优化问题转换为自动搜索问题,并加入了自动化验证和一阶逻辑验证,在发现代码优化空隙的同时优化结果也更加可靠。</p><p><a href="https://github.com/jiazhihao/TASO">TASO</a>(Tensor 
Algebra SuperOptimizer)将Superoptimization用于DNN高层中间表达的图优化,在大多数模型上取得了比XLA和TensorRT更优的效果。TASO的工作是MetaFlow(作者另一个基于人工规则的图替换框架)的延续,因此也采用了与MetaFlow一致的替换子定义。MetaFlow替换子的定义包括:源图、目标图、输入和输出的映射关系。</p><p><img src="https://github.com/jiazhihao/TASO/blob/master/figures/inference.png?raw=true"></p><p>TASO相比其他开源框架最大的区别就是不需要手工去设定各种各样的替换子,只需要像设计硬件指令一样设计出基本的算子定义(或者计算逻辑),之后系统会根据指定的算子集自动生成满足条件的替换子,经过验证的替换子最终作用于图替换过程。基于高度抽象的替换子定义,TASO可以独立于具体的训练或预测框架,离线完成替换子的生成和验证,并在图优化阶段加载到程序中进行图替换。尽管手工设计有很多弊端,但TASO在代码实现过程中并没有完全抛弃手工设计的方式,而是采用了手工设计和替换子自动生成相结合的方式。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-29%2013.56.06.png?raw=true" style="zoom:40%;" /></p><h4 id="替换子定义">替换子定义</h4><p>替换子包含三个部分,源图、目标图、输入和输出tensor的映射关系。并且替换子通常是与shape无关的,源图和目标图都是由算子构成的,每个算子都可以指定一些配置,比如kernel指定卷积核的大小、axis指定reduce的维度等等。</p><p>但需要注意的是concat和split两个算子,在图替换中这两个算子通常用于算子融合,比如下图对两个不同的输入B和C进行相同的MatMul操作,就可以替换为先将输入B和C进行一次合并,然后调用一次MatMul后,对结果进行切分得到两个输出X和Y。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-29%2016.24.40.png?raw=true" style="zoom:50%;" /></p><p>为了能正确切分出X和Y,在Concat时我们需要给每个维度维护一个分割树(split tree)。一个行分割的例子如下,图中需要将A和B按照第0维进行concat,因此输入A在第0维有一个原始的分割树[0,<span class="math inline">\(S_{A}\)</span>],表示对于tensor A,第0维从0到<span class="math inline">\(S_{A}\)</span>行都是A的数据区域。A和B concat后tensor的row变成了<span class="math inline">\(S_{A}+S_{B}\)</span>,并且通过分割树可以知道第0到<span class="math inline">\(S_{A}\)</span>行是A的数据,从<span class="math inline">\(S_{A}\)</span>到<span class="math inline">\(S_{A}+S_{B}\)</span>行是B的数据。根据分割树,Split非常容易地就可以将数据进行切分。TASO的分割树支持任意维度的切分和递归切分。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-29%2016.37.22.png?raw=true" style="zoom:70%;" /></p><h4 
id="替换子生成">替换子生成</h4><p>替换子生成包含两个阶段:构建搜索空间,以及对潜在的替换子进行测试。</p><ul><li><p>构建搜索空间</p><p>搜索空间由任意合法的计算图构成,而计算图由给定的算子集中的算子组成。TASO向我们表明了一种暴力枚举、深度优先递归构建的方法。</p><p>给定算子集和初始化的input tensor集合,对于每一个输入tensor,每次从算子集中选择一个合法的算子构建graph,并计算当前graph的输出tensor,将输出tensor加入到input tensor集合,保存graph以及graph的fingerprint(对输出tensor计算hash值),接着重复上面的过程继续加入算子,直到递归的深度达到设定的上限。</p><p>对于同样的输入tensor,如果构建的两个计算图的输出tensor相同,则这两个计算图构成了一个潜在的替换子。为了避免出现浮点计算异常的情况,构建计算图时所有的tensor都是int类型。</p></li><li><p>测试潜在替换子</p><p>为了进一步验证潜在替换子的合法性,TASO设计了一系列cases来测试潜在替换子。每个测试case都使用随机初始化的输入tensor,当两个计算图结果一致时才认为测试通过,只有所有测试cases都通过的潜在替换子才是合法的替换子。</p><p>与构建计算图时使用int类型的tensor不一样,所有测试case的输入tensor都是-1到1之间的浮点数。由于relu对于所有小于0的值都返回0,因此可能导致非法的替换子通过测试cases,作者认为可以使用任意一个非线性函数来代替relu,TASO中使用<span class="math inline">\(x(x+1)+1\)</span>。</p></li></ul><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-30%2016.30.05.png?raw=true" style="zoom:50%;" /></p><h4 id="替换子验证">替换子验证</h4><p>TASO同时使用一阶逻辑表达的算子属性对替换子进行进一步验证,这些属性通常是由人工定义,并且经过充分review和大量测试验证过的。</p><p>在定义算子属性之前,首先需要对算子进行符号建模,算子模型通常包含参数和输入tensors。比如<span class="math inline">\(conv(s, p, c, x, y)\)</span>表示conv算子的符号模型,<span class="math inline">\(s\)</span>,<span class="math inline">\(p\)</span>,<span class="math inline">\(c\)</span>是conv的参数,分别表示stride、padding和activation,<span class="math inline">\(x\)</span>和<span class="math inline">\(y\)</span>是卷积操作的两个输入。如果activation是none,很显然conv就是一个线性操作,因此满足以下属性:<span class="math display">\[\begin{aligned}∀s,p,x,y,z. 
conv(s,p,Anone,ewadd(x,y),z) = \\ ewadd(conv(s,p,Anone,x,z),conv(s,p,Anone,y,z))\end{aligned}\]</span>TASO定义了大量的算子属性,并且使用z3(一阶逻辑验证器)对所有合法的替换子进行验证。如果有合法的替换子无法被一阶逻辑验证,则需要根据替换子手动添加一条算子属性,以确保所有合法的替换子都能验证通过。</p><h4 id="冗余替换子裁剪">冗余替换子裁剪</h4><p>自动生成的替换子往往存在大量的冗余,TASO使用了两种策略消除冗余。</p><ul><li><p>Input tensor renaming</p><p>通过对输入进行重命名的方式消除不同替换子之间的冗余。比如下面两个替换子,</p><p>替换子a:</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-30%2018.14.31.png?raw=true" style="zoom:40%;" /></p><p>替换子b:</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-30%2018.15.49.png?raw=true" style="zoom:40%;" /></p><p>将替换子a的一个输入tensor A改名为C,就得到了替换子b,说明这两个替换子存在冗余,因此最终只会保留更加通用的替换子b。</p></li><li><p>Common subgraph</p><p>如果替换子的源图和目标图包含同样的子图,则可以用一个相同的tensor替换掉公共子图。比如下面的一个替换子,</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-30%2018.14.59.png?raw=true" style="zoom:40%;" /></p><p>source graph和target graph包含同一个子图(B x C),将source graph替换成target graph时,公共子图没有任何变化,因此可以将子图消除。</p></li></ul><p>实验结果显示,裁剪可以消除大量的冗余替换子。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/substitution/屏幕快照%202019-12-30%2018.12.39.png?raw=true" style="zoom:50%;" /></p><h3 id="低精度和layout优化">低精度和layout优化</h3><h3 id="相关资料">相关资料</h3><ol type="1"><li>https://cs.stanford.edu/~zhihao/papers/sosp19.pdf<br /></li><li>https://github.com/jiazhihao/TASO<br /></li><li>TensorFlow Graph Optimizations, https://web.stanford.edu/class/cs245/slides/TFGraphOptimizationsStanford.pdf<br /></li><li>https://github.com/google/souper</li></ol>]]></content>
<summary type="html"><h3 id="背景">背景</h3>
<p>图替换(或者叫图改写)是一种重要的图优化技术,几乎在所有的开源框架(尤其是移动端框架)中都有应用。比如tensorflow
r1.14版本中就包含了155个替换子,而且实现这些替换子的总代码量接近53k行。</p>
<blockquote>
<p>一些常见的图优化技术:</p>
<ul>
<li><p>DCE</p></li>
<li><p>CSE(公共子表达式消除)</p></li>
<li><p>常量折叠</p></li>
<li><p>数学公式简化</p></li>
<li><p>Op融合</p></li>
<li><p>Layout变换</p></li>
<li><p>内存优化(swap-in/swap-out、重计算)</p></li>
</ul>
</blockquote></summary>
<category term="graph optimization, 图优化" scheme="https://hjchen2.github.io/categories/graph-optimization-%E5%9B%BE%E4%BC%98%E5%8C%96/"/>
<category term="图替换" scheme="https://hjchen2.github.io/tags/%E5%9B%BE%E6%9B%BF%E6%8D%A2/"/>
<category term="超优化" scheme="https://hjchen2.github.io/tags/%E8%B6%85%E4%BC%98%E5%8C%96/"/>
<category term="graph optimization" scheme="https://hjchen2.github.io/tags/graph-optimization/"/>
<category term="super optimization" scheme="https://hjchen2.github.io/tags/super-optimization/"/>
<category term="substitution" scheme="https://hjchen2.github.io/tags/substitution/"/>
</entry>
<entry>
<title>FusionStitching, Deep Fusion and Code Generation for Tensorflow Computations on GPUs</title>
<link href="https://hjchen2.github.io/2019/11/27/DeepFusion/"/>
<id>https://hjchen2.github.io/2019/11/27/DeepFusion/</id>
<published>2019-11-27T04:00:04.000Z</published>
<updated>2023-02-07T02:39:43.223Z</updated>
<content type="html"><![CDATA[<h2 id="fusionstitching系统概述">FusionStitching系统概述</h2><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-25%2013.56.40.png?raw=true" alt="屏幕快照 2019-11-25 13.56.40" /><figcaption aria-hidden="true">屏幕快照 2019-11-25 13.56.40</figcaption></figure><p>输入HloModule,经过以下三个阶段,最终输出LLVM IR。</p><ul><li>Computation Fusion</li><li>Schedule Planning</li><li>Code Generation</li></ul><p>论文主要针对XLA的Fusion算法进行了改进,提出了实现Block合并策略的Schedule和Shared Memory Planning技术,以及实现对应的IR Emitter。</p><span id="more"></span><h2 id="computation-fusion">Computation Fusion</h2><p>利用Work/Span analysis,将instruction划分到不同的layer,然后DeepFusion模块在Schedule Consistency Checker的指导下完成跨layer的instruction合并。该过程是迭代进行的,直到完全没有合并机会。</p><h3 id="workspan-analysis">Work/Span analysis</h3><blockquote><p>Work/Span analysis通常用于并行算法的分析。假设每个基本运算执行时间都是单位时间,则Work表示的是所有基本运算时间总和,Span表示最长依赖路径上的基本运算时间总和。对于一个计算图来说,可以简单认为图中所有的计算节点总执行时间表示Work,而计算图的最大深度的路径上的节点的顺序执行总时间表示Span。</p></blockquote><p>在这里作者用Span来表示每个节点到root节点的深度。</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-26%2018.28.17.png?raw=true" alt="屏幕快照 2019-11-26 18.28.17" /><figcaption aria-hidden="true">屏幕快照 2019-11-26 18.28.17</figcaption></figure><p>经过Work/Span analysis后,HloModule中的Instruction被划分到了不同的layer,相同Span值的Instruction的layer相同,并且同一layer的Instruction没有依赖关系。</p><h3 id="subgraph-fusion-algorithm">Subgraph Fusion Algorithm</h3><p>基于Work/Span analysis计算得到的Span值,作者提出了不同于XLA的Fusion算法。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-26%2019.44.59.png?raw=true" width=600/></p><p>SchdConsistent用来判断fusion_root和hlo是否应该合并,其具体的执行逻辑如下:</p><ol 
type="1"><li>如果hlo有一个consumer在giveup集合中,为了防止潜在的循环依赖,放弃fusion。</li><li>如果hlo的所有consumer都不在fused集合中,则放弃fusion,因为这里只考虑Producer/Consumer的合并,没有消费关系的Instruction合并则会在ElementwiseFusion算法中处理。</li><li>最后会判断合并后的Computation是否存在一个可行的optimizedshedule。如果不存在,则放弃fusion。</li></ol><blockquote><ul><li>算法简单高效,Work/SpanAnalysis的作用其实相当于对Instruction做了一遍拓扑排序,通过简单的合并规则确保circlefree。</li><li>不区分expensive op,可以通过sharedmemory来缓存中间结果,因此不需要重计算。</li><li>由于第一条约束的强制性,导致合并不完全。</li></ul></blockquote><h2 id="schedule-planning">Schedule Planning</h2><h3 id="schedule定义">Schedule定义</h3><blockquote><p>Schedule通常指的是将算法指定的计算过程分配给计算资源的方法。这些计算过程可能包括线程、进程以及数据流等。</p><p>常见的一些Schedule配置: - Reorder 循环顺序重排,比如for x for y ->for y for x - Tile - Unroll - Vectorize - Parallel - some CUDA-specific比如blocks、threads、shared memory size等。</p><p>对于包含多个计算stage的算法,Schedule通常是由Consumer驱动,并指定何时对Consumer计算Producer(<strong>Specifywhen the producer is computed with respect to the consumer</strong>)。</p></blockquote><p>论文中将Instruction大致分成Elementwise、Transpose、Reduce、BatchDot、Reshape和Broadcast这几种,然后基于这些op定义了一套用来表示对数据分块的Shedule配置。通过一个定义好的Shedule配置和数据的shape,我们就可以知道需要切成多少个数据块,映射到GPU硬件上就是多少个线程块(threadblocks)。</p><figure><imgsrc="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-27%2011.22.57.png?raw=true"alt="屏幕快照 2019-11-27 11.22.57" /><figcaption aria-hidden="true">屏幕快照 2019-11-27 11.22.57</figcaption></figure><p>Shedule定义在输出shape上,包含三个字段:split_dim、sword和sched_type。split_dim表示切割的维度,取值[0,num_dims)。sword表示在split_dim维度上切分多少块,sword要求能被split_dim维度K整除。sched_type表示行切割还是列切割,取值Row或者Column。给定一个Instruction,其Schedule空间即所有合法的三元组(split_dim、sword和sched_type)。</p><p>上图表示Reduce Instruction的两种合法Schedule,通过split_dim和reducedim来区分Row Schedule和Column Schedule。</p><h3 id="schedule约束和传播">Schedule约束和传播</h3><p>与Instruction的Schedule定义在输出shape上一样,Computation的Schedule也定义在RootInstruction的输出上,因为Root Instruction是整个Computation的输出。<br />对于一个Fused 
Computation,需要满足Shedule相容约束:即对于RootInstruction,给定一个合法的Shedule,该Shedule需要同时被其他Instruction相容。论文中提出后向传播的方法来判断Shedule约束的相容性。<br />任意一个Instruction,其合法的Schedule可以根据Instruction类型和outputshape来确定。如果给定的Schedule满足约束(是合法的),则把Schedule后向传播到输入shape(s),也就是输入Instruction的输出shape。否则从RootInstruction传播过来的Schedule在整个FusedCompution上不满足相容性约束。</p><blockquote><p>在Subgraph Fusion算法中,两个Instruction能否合并除了需要满足circlefree约束外,还需要满足后端CodeGen模块的支持。只有Schedule满足约束,CodeGen才能正确发射代码,否则CodeGen无法处理。</p></blockquote><figure><imgsrc="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-27%2013.53.21.png?raw=true"alt="屏幕快照 2019-11-27 13.53.21" /><figcaption aria-hidden="true">屏幕快照 2019-11-27 13.53.21</figcaption></figure><p>Table.1表明了不同Instruction的Schedule后向传播规则。Schedule约束判断结果会反馈到SubgraphFusion过程,Fusion不应该破坏Schedule相容性约束。</p><h3 id="schedule-tuning">Schedule Tuning</h3><p>任意一个Instruction,split_dim=0和sword=1的RowSchedule总是合法的,也就是只有一个数据块,并且只用一个GPU线程块来计算。这样做的问题也很明显,就是无法充分利用GPU硬件资源。每个Instruction可能有多个合法的Schedule,ScheduleTuning用来选择一个合适的Schedule。<br />如果Computation中只有一个Root,遍历该RootInstructon所有合法的满足约束的Schedule,在performancelibrary中查找每个kernel的执行时间,并统计总耗时。总耗时最少的Schedule会被选择用来CodeGeneration。</p><p>如果Computation中有多个Roots,则采取一种two-stage的方法加速Schedule的搜索过程。<br />第一步:遍历所有的Roots,计算blocks和blocks对应的Schedule两个序列。对所有Roots对应的blocks序列求交集,结果对应的Schedule即合法的候选Schedule。<br />第二步:遍历所有的候选Schedule,计算每个Schedule下所有kernel的耗时,选择耗时最少的Schedule。论文中还提到可以忽略部分ops和earlystop的搜索策略,加速第二步的搜索过程。</p><h2 id="code-generation">Code Generation</h2><h3 id="shared-memory-planning">Shared Memory Planning</h3><p>标记出所有可能需要用到SharedMemory的候选ops,当Memory不足时优先满足most critical ops。</p><ul><li><p>Size Requirement Analysis</p><ol type="1"><li><p>直接分配 对于非RootInstruction的Reduce和BatchDot,必须将中间结果放在Shared Memory,allowingconsumer ops to use seperate parallel loop emitters to generatecode。</p></li><li><p>按优先级分配 对于有多个Users的Elementwiseop,为了避免重计算,可以选择将结果缓存到SharedMemory。在memory受限的情况下,按照优先级(expensive op > 非expensiveop)确定使用Shared Memory。<br 
/>有时对于只有一个User的expensive op也需要用到SharedMemory,比如如果expensiveop后面接了一个BatchDot,由于BatchDot本身对数据的复用性比较高,将expensiveop的结果缓存到Shared Memory会带来非常好的性能优化。</p></li></ol></li><li><p>Size Shrinking</p><p>Size Shrinking与上面RequirementAnalysis的第2点类似。当每个线程Block分到的数据块非常大时,可能存在SharedMemory受限的问题。解决办法就是让一些ops退化为重计算。<br />从inexpensive ops开始,然后expensive ops,之后是带有BtachDot的expensiveops,最后按照靠近RootInstruction的程度选择候选ops,并优先选择靠近输出的ops。</p></li><li><p>Space Sharing</p><p>不同ops分配的Shared Memory是可以复用的,论文中作者提出从RootInstruction开始构造一颗支配树,支配节点可以复用被支配节点申请的SharedMemory。</p></li></ul><h3 id="code-generation-1">Code Generation</h3><p>XLA使用GpuElementalIrEmitter来实现线程合并的Computation。基于XLA的GpuElementalIrEmitter,作者实现了用于Block合并的IrEmitter(论文中称作IrEmitterStitched)。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-27%2017.26.12.png?raw=true" width=600/></p><p>IrEmitterStitched输入有hlo、root、shared、schedule和generators。</p><ul><li>hlo: 待处理的hlo Instruction</li><li>root: 是否是root Instruction</li><li>shared: 是否将输出写到shared memory</li><li>shedule: row schedule还是column schedule</li><li>generators:与XLAGpuElementalIrEmitter中的generators_类似,但是能处理sharedmemory的情况。</li></ul><p>基本逻辑如下:</p><ol type="1"><li>如果待处理的Instruction不是root Instruction,并且没有用到SharedMemory,不是Dot和ReduceOpcode,则回退到XLA的GpuElementalIrEmitter中去处理,否则使用IrEmitterStitched发射LLVM代码。</li><li>如果需要用到SharedMemory,则调用EmitWriteSharedArray将结果写到Shared Memory。</li><li>如果是root Instruction,则调用EmitWriteOutputArray将结果写到GlobalMemory。如果不是rootInstruction,则调用EmitGenerator在generators中创建一个entry,以支持当前Instruction与其他Instruction的合并。</li></ol><h2 id="xla-op-fusion规则">XLA op fusion规则</h2><ul><li><p>Consumer本身支持合并</p><p>特定op不支持与Producer合并,比如Parameter、While、Conditional、Call等,以及op本身hasa side effect或者调用了has a sideeffect的op。此外被标记为tracing的op也无法合并。</p></li><li><p>Consumer与Producer之间支持合并</p><ul><li>Consumer和Producer之间所有的op均可以被合并到Consumer。</li><li>对于Consumer和Producer之间所有的op:<ol type="1"><li>如果直接Producer已经是一个Fusion 
op,则不能合并。</li><li>对Reduce和Scatter,以及CustomCall/LibraryCall的一些限制。</li><li>如果直接Producer有其他Consumer,则Fusion会导致该Producer需要重计算。如果Producer属于expensive op或为Parameterop则放弃合并。</li></ol></li></ul></li></ul>]]></content>
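The Work/Span layering step described in the post — each instruction's Span is the length of its longest dependency path to the root, and instructions sharing a Span value form one dependency-free layer — can be sketched as follows. The graph encoding and node names are illustrative assumptions, not from the paper.

```python
from collections import defaultdict
from functools import lru_cache

def layer_by_span(consumers, root):
    """Group instructions into layers keyed by Span.

    consumers maps each instruction to the tuple of instructions
    that consume its output; the root has no consumers.
    """
    @lru_cache(maxsize=None)
    def span(node):
        if node == root:
            return 0
        # longest path to the root: 1 + max over all consumers
        return 1 + max(span(c) for c in consumers[node])

    layers = defaultdict(list)
    for node in consumers:
        layers[span(node)].append(node)
    return dict(layers)
```

Instructions in the same layer can then be considered for fusion without risking a cycle, which is the property the SchdConsistent checks build on.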
<summary type="html"><h2 id="fusionstitching系统概述">FusionStitching System Overview</h2>
<figure>
<img
src="https://github.com/hjchen2/personal/blob/master/blog/DeepFusion/屏幕快照%202019-11-25%2013.56.40.png?raw=true"
alt="屏幕快照 2019-11-25 13.56.40" />
<figcaption aria-hidden="true">屏幕快照 2019-11-25 13.56.40</figcaption>
</figure>
<p>The input is an HloModule, which passes through the following three stages and is finally lowered to LLVM IR.</p>
<ul>
<li>Computation Fusion</li>
<li>Schedule Planning</li>
<li>Code Generation</li>
</ul>
<p>The paper improves on XLA's fusion algorithm, proposing Schedule and Shared
Memory Planning techniques that implement a block-level fusion strategy,
together with the corresponding IR emitter.</p></summary>
<category term="DL Compiler" scheme="https://hjchen2.github.io/categories/DL-Compiler/"/>
<category term="XLA" scheme="https://hjchen2.github.io/tags/XLA/"/>
<category term="Deep Learning Compiler" scheme="https://hjchen2.github.io/tags/Deep-Learning-Compiler/"/>
<category term="FusionStitching" scheme="https://hjchen2.github.io/tags/FusionStitching/"/>
</entry>
<entry>
<title>Mixed Precision Training</title>
<link href="https://hjchen2.github.io/2018/02/03/%E6%B7%B7%E5%90%88%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83/"/>
<id>https://hjchen2.github.io/2018/02/03/%E6%B7%B7%E5%90%88%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83/</id>
<published>2018-02-03T04:00:04.000Z</published>
<updated>2023-02-07T02:39:53.138Z</updated>
<content type="html"><![CDATA[<h2 id="mixed-precision-training">MIXED PRECISION TRAINING</h2><p><ahref="https://email.baidu.com/OWA/redir.aspx?C=G_TpaBQZHjfotfty5PDuHfO3av_KUOGPcZOg_60U2vdUx9QS42vVCA..&URL=https%3a%2f%2farxiv.org%2fpdf%2f1710.03740.pdf">https://arxiv.org/pdf/1710.03740.pdf</a></p><h3 id="论文概述">论文概述</h3><p>nvidia的Pascal和Volta系列显卡除了支持标准的单精度计算外,也支持了低精度的计算,比如最新的TeslaV100硬件支持了FP16的计算加速,P4和P40支持INT8的计算加速,而且低精度计算的峰值要远高于单精浮点的计算峰值。</p><span id="more"></span><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/95247900845ca0aa285aea86b971c6ec.png?raw=true'></p><p>为了加速训练过程以及减少显存开销,baiduResearch和nvidia在这篇论文中合作提出了一种FP16和FP32混合精度训练的方法,并且在CNN分类和检测、语音识别和语言模型任务上进行了验证,实验过程中使用的GPU就是TeslaV100。</p><p>训练过程中每层的权重都存成FP32格式(Mater-Weights),每次训练时都会将FP32的权重降精度至FP16(a mastercopy),前向输出和后向梯度都使用FP16进行计算,更新时将FP16的梯度累加到FP32的Mater-Weight上。</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/b89a595f09deb2caf14d44176f931440.png?raw=true'></p><h3 id="混合精度的必要性">混合精度的必要性</h3><p>由于FP16所能表示的subnormal最小正数是<spanclass="math inline">\(2^{−24}\)</span> ≈ <spanclass="math inline">\(5.96 × 10^{−8}\)</span>(<ahref="https://en.wikipedia.org/wiki/Half-precision_floating-point_format">Half-precisionfloating-point format</a>),也就是说在区间(<spanclass="math inline">\(-2^{-24},2^{-24}\)</span>)的数(或者说指数位小于-24的数)使用FP16表示时都会变成0。在一个普通话识别的模型训练中,有将近5%的权重梯度的指数位小于-24,如果更新时也用FP16计算,那么这些数在乘以学习率后都将变成0,从而对最终模型效果产生负面影响,使用混合精度训练的方式可以避免这种问题。</p><h3 id="loss-scaling">Loss scaling</h3><p>混合精度训练可以解决权重更新量很小的问题,但无法解决梯度本身很小的问题。在一些网络中(比如SSD),梯度大部分都在FP16的表示范围之外,因此需要将梯度平移到FP16的表示范围内。</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/fc960bb10d950d111404cda831aa5cbe.png?raw=true'></p><p>平移实际上就是对梯度值乘以一个系数(等于<spanclass="math inline">\(2^{n}\)</span>,<spanclass="math 
inline">\(n\)</span>为平移的位数),但另一种简单高效的方法是直接在前向时就将loss乘以scale,这样在后向传导时所有的梯度都会被乘以相同的scale。权重更新时需要将移位后的梯度除以scale后,再更新到权重上。</p><p>论文中提到他们在实验过程中使用的scale是8~32K,最终取得了与FP32一致的收敛结果。对于scale的选择,论文没有统一的方法,只是提到scale并没有下界,只要选择的scale不会在后向计算时导致溢出就行。</p><h3 id="实验结果">实验结果</h3><ul><li><p>图像分类</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/a9166bfb03d36772c83f4aa56e591374.png?raw=true'></p></li><li><p>物体检测</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/3dbc1922becd3b150d50bc71aacecb1e.png?raw=true'></p></li><li><p>语音识别</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/0369372f891c65571c845b04960aafda.png?raw=true'></p></li><li><p>机器翻译</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/237914e80a50fe0f2cac573c36733e5c.png?raw=true'></p></li><li><p>语言模型</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/f1c1f41006c8f637c29208ac8652310b.png?raw=true'></p><p></p></li></ul><h2id="mixed-precision-training-of-convolutional-neural-networks-using-integer-operations">MIXEDPRECISION TRAINING OF CONVOLUTIONAL NEURAL NETWORKS USING INTEGEROPERATIONS</h2><p><ahref="https://email.baidu.com/OWA/redir.aspx?C=a0s4Pl45ENd9uqHgfl_L2eKY-IGy51CKRbN_JHdP0YhUx9QS42vVCA..&URL=https%3a%2f%2fopenreview.net%2fforum%3fid%3dH135uzZ0-">https://openreview.net/forum?id=H135uzZ0-</a></p><h3 id="论文概述-1">论文概述</h3><p>半精度(16bit)分为半精度浮点(FP16)和半精度定点(INT16),FP16和INT16提供不同的精度和表示范围。INT16相比FP16的动态范围低,但精度更高,因此INT16相比FP16会带来更低的精度误差。</p><p>现在深度学习领域公认的数据类型是单精度浮点(float),半精和单精除了在直观感觉上的数据类型不同之外,在计算(algorithmic)和语义(semantic)上也会有很多的不同,比如说FP16的乘加操作得到的结果是FP32。因此在讨论半精度训练时,对于整个tensor的表达、乘加操作、低精度转换、缩放和规整方法和溢出处理都是需要同时考虑的。</p><p>intel的这篇论文主要受到之前flexpoint和混合精度训练的启发,从而提出了一种共享指数位的动态定点表达(dynamicfixed 
pointrepresentation)方法,使用INT16和float混合精度训练,在完全不进行任何调参的情况下,在多个CNN的模型上取得了当前所有低精度训练方法中最好的效果。</p><p>这篇论文主要涉及的技术点有:</p><ul><li>DFP:INT16的Tensor共享指数位,扩充INT16的动态表示范围。</li><li>instruction:两个INT16进行乘法,结果存为INT32的指令。</li><li>down-convert:基于最大值的低精度转换策略,使用nearest、stochastic和biasedrounding三种不同的rounding方法。</li><li>overflowmanagement:将局部的INT32结果累加到FP32,防止累加时溢出。</li></ul><h3 id="dfpdynamic-fixed-point">DFP(Dynamic Fixed Point)</h3><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/f54c9019a7174299761d48094d1f0dab.png?raw=true'></p><p>一个DFPtensor由一个定点的tensor和该tensor共享的指数组成,更通用的表示形式为DFP-P= <span class="math inline">\(<I, E_{s}>\)</span>,P表示定点tensor<span class="math inline">\(I\)</span>的位宽,<spanclass="math inline">\(E_{s}\)</span>表示共享指数位。标准单精使用的是8bit的指数位,在该论文中使用的DFP-16共享指数位也是8bit。</p><ul><li><p>DFP-16和fp32的数据转换</p><p>共享指数位需要根据tensor中的绝对值最大的数和定点化的位宽来确定,计算公式如下:</p><p><span class="math display">\[E_{s} = E_{fmax} - (P - 2)\]</span></p><p><span class="math inline">\(E_{s}\)</span>表示DFP-P的共享指数,<spanclass="math inline">\(E_{fmax}\)</span>表示原始fp32tensor中绝对值最大的数对应的指数<span class="math inline">\(E_{fmax} =E(max_{\forall f \in F} |f|)\)</span></p><p>因此fp32的tensor与DFP的tensor有以下关系:</p><p><span class="math display">\[\forall i_{n} \in I, \ \ \ f_{n} = i_{n}\times 2^{E_{s}}, \ \ \ where f_{n} \in F\]</span></p><p>也就是说<span class="math inline">\(i_{n} =rounding(\frac{f_{n}}{2^{E_{s}}})\)</span>,这本质上与lossscaling思想是一样的,用平移的思想来解决动态范围不够的问题。</p></li><li><p>DFP-16 tensor的乘加运算规则</p><p>1、两个DFP-16 tensor相乘,结果存为DFP-32。</p><p><span class="math display">\[i_{ab} = i_{a} \times i_{b} , \ \ \E_{s}^{ab} = E_{s}^{a} + E_{s}^{b}\]</span></p><p>2、两个DFP-16 tensor相加,结果存为DFP-32。</p><p><span class="math display">\[i_{ab} = \left\{\begin{aligned} i_{a} +(i_{b} >> (E_{s}^{a} - E_{s}^{b})) \ \ \ when E_{s}^{a} >E_{s}^{b} \\ i_{b}+(i_{a} >> (E_{s}^{b}-E_{s}^{a})) \ \ \ whenE_{s}^{a} < E_{s}^{b} \end{aligned}\right.\]</span></p><p><span class="math display">\[E_{s}^{a+b} = 
max(E_{s}^{a},E_{s}^{b})\]</span></p><p>3、两个DFP-32 tensor相加,结果保存为fp32。</p></li><li><p>DFP-32和DFP-16的数据转换</p><p><span class="math display">\[R_{s} = P - LZC(max_{\forall i_{ab} \inI^{32}}|i_{ab}|)\]</span></p><p><span class="math display">\[i_{ab}^{d} = i_{ab} >> R_{s} , \ \\ E_{s}^{ab} += R_{s}\]</span></p></li></ul><h3 id="dfp混合精度训练">DFP混合精度训练</h3><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/9b304e74b8dbc9ec6324c56d05b85f24.png?raw=true'></p><h3 id="指令实现">指令实现</h3><p>intel的VNNI指令集中有一条DFP-16乘加的指令QVNNI16,这条指令的第一个操作数是DFP-16内存指针,第二个操作数是4个512位的向量寄存器(每个寄存器可以存储32个DFP-16),结果是一个512位的向量寄存器(该寄存器能存储16个DFP-32)。</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/4a18c89da9676673a73c899987564e10.png?raw=true'></p><p>上面的QVNNI16指令集实际上对mem输入做了两路并行展开,vinp2中一个寄存器支持同时对输入featuremap的两个channel进行计算。在论文中,卷积层输入的格式为(N,C/16,H,W,16),权重的格式为(C/16,K/16,KH,KW,8c,16k,2c),C表示输入featuremap的通道数,K表示输出通道数,KH和KW分别表示卷积核的height和width。</p><p>卷积计算过程伪代码:</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/2f5405a955c03cd522b5b1f17e7300cd.png?raw=true'></p><p>每次对输入的ICBLK个通道进行计算,ICBLK个通道又会分成(ICBLK/16)组,每组计算16个通道,由于QVNNI指令每次只能对输入的8个通道进行计算,因此每组调用2次QVNNI16指令,计算结果vout会转换成FP32后与output累加。</p><h3 id="实验结果-1">实验结果</h3><p>baseline和DFP-16的实验均在intel最新的Knights-MillCPU上进行,DFP-16相比FP32训练加速1.8X。</p><p><img src='https://github.com/hjchen2/personal/blob/master/blog/mixed-precision/55d321517c2de03fe92f7c32aff1d87a.png?raw=true'></p><h3 id="abs_max量化方案">ABS_MAX量化方案</h3><h3 id="dfp与abs_max量化的区别">DFP与ABS_MAX量化的区别</h3>]]></content>
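The loss-scaling trick from the first paper — multiply the loss by a scale so every FP16 gradient is shifted into representable range, then divide by the same scale before accumulating into the FP32 master weights — can be illustrated with a small NumPy sketch. The scale of 1024 and the gradient magnitude are illustrative choices; the paper reports scales of 8–32K working in practice.

```python
import numpy as np

SCALE = np.float32(1024.0)  # 2**10, an illustrative loss scale

true_grad = np.float32(2e-8)                 # below the FP16 subnormal minimum 2**-24 ≈ 5.96e-8
naive_fp16 = np.float16(true_grad)           # flushes to zero: the update would be lost
scaled_fp16 = np.float16(true_grad * SCALE)  # scaling the loss scales every gradient by SCALE
recovered = np.float32(scaled_fp16) / SCALE  # unscale in FP32 before the master-weight update

# master weights stay in FP32, as in mixed precision training
master_w = np.float32(0.5)
master_w -= np.float32(0.1) * recovered
```

Without the scale the gradient underflows to zero in FP16; with it, the unscaled FP32 value is recovered to within FP16 precision.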
<summary type="html"><h2 id="mixed-precision-training">MIXED PRECISION TRAINING</h2>
<p><a
href="https://email.baidu.com/OWA/redir.aspx?C=G_TpaBQZHjfotfty5PDuHfO3av_KUOGPcZOg_60U2vdUx9QS42vVCA..&amp;URL=https%3a%2f%2farxiv.org%2fpdf%2f1710.03740.pdf">https://arxiv.org/pdf/1710.03740.pdf</a></p>
<h3 id="论文概述">论文概述</h3>
<p>Besides standard single-precision arithmetic, NVIDIA's Pascal and Volta GPUs also support lower-precision computation: the latest Tesla
V100 accelerates FP16, while the P4 and P40 accelerate INT8, and the peak throughput of low-precision computation is far higher than the single-precision floating-point peak.</p></summary>
<category term="low bitwidth" scheme="https://hjchen2.github.io/categories/low-bitwidth/"/>
<category term="int16" scheme="https://hjchen2.github.io/tags/int16/"/>
<category term="fp16" scheme="https://hjchen2.github.io/tags/fp16/"/>
<category term="混合精度训练" scheme="https://hjchen2.github.io/tags/%E6%B7%B7%E5%90%88%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83/"/>
<category term="loss scaling" scheme="https://hjchen2.github.io/tags/loss-scaling/"/>
<category term="QVNNI16" scheme="https://hjchen2.github.io/tags/QVNNI16/"/>
</entry>
<entry>
<title>Model Compression: Pruning</title>
<link href="https://hjchen2.github.io/2018/01/02/%E6%A8%A1%E5%9E%8B%E5%8E%8B%E7%BC%A9%E8%AE%BA%E6%96%87%E9%98%85%E8%AF%BB%E8%AE%B0%E5%BD%95/"/>
<id>https://hjchen2.github.io/2018/01/02/%E6%A8%A1%E5%9E%8B%E5%8E%8B%E7%BC%A9%E8%AE%BA%E6%96%87%E9%98%85%E8%AF%BB%E8%AE%B0%E5%BD%95/</id>
<published>2018-01-02T14:00:04.000Z</published>
<updated>2023-02-07T02:40:11.362Z</updated>
<content type="html"><![CDATA[<h2id="regularization-of-neural-networks-using-dropconnect">Regularizationof Neural Networks using DropConnect</h2><ul><li>DropConnect主要是用来解决全连接过拟合问题的,是Dropout的通用实现。随着神经网络参数量越来越大,过拟合的风险越来越高,之前的一些经验是使用L1/L2以及Dropout。Dropout随机地将激活函数输出置0,导致每次参与训练的参数量变少,由于随机drop的关系,每次训练的网络都可能不一样,因此实际上我们训练的是多个子模型组成的混合模型。</li></ul><span id="more"></span><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/0.png" /></p><ul><li><p>Dropout</p><p>如果考虑激活函数为tanh和relu,则dropout的输出:</p><p><span class="math display">\[r=m*a(Wv)=a(m*(Wv))\]</span></p><p>inference时混合模型的输出:</p><p><span class="math inline">\(o=E_{M}[a(M*(Wv))] \approxa(E_{M}[(M*W)v])=a(pWv)\)</span></p><p><span class="math inline">\(M\)</span>是<spanclass="math inline">\(m\)</span>的repeat得到的矩阵。</p></li><li><p>DropConnect</p><p>随机地将全连接层的权重值置0,即输出为:</p><p><span class="math display">\[r=a((M*W)v)\]</span></p><p><span class="math inline">\(M\)</span>是与<spanclass="math inline">\(W\)</span>大小一致的0-1矩阵,并且<spanclass="math inline">\(M_{ij}\)</span>服从Bernoulli(p)分布。</p><p>inference时混合模型的输出:</p><p><span class="math display">\[o=E_{M}[a((M*W)v)] \approx E_{u}[a(u)]\]</span></p><p>where <span class="math inline">\(u\sim N(pWv,p(1-p)(W*W)(v*v))\)</span></p><p>注:对于<spanclass="math inline">\(u\)</span>的分布论文中提到用高斯矩匹配估计,但也可以用中心极限定理进行估计</p><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/3.png" /></p></li></ul><p>训练时的伪代码: <imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/1.png" /></p><p>inference时的伪代码: <imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/2.png" /></p><ul><li><p>实验结果</p><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/4.png" 
/></p></li><li><p>总结</p><p>DropConnect的初衷是解决过拟合问题的,DropConnect虽然在训练时可以将稠密矩阵乘转化成稀疏乘的方式,减少计算量,但在inference时还是需要完整的计算一遍,然后再利用正态分布多次采样后计算均值得到下一层的输入,因此inference的计算量反而增加了。论文给出的实验结果表明DropConnect在tanh和relu激活函数时会比dropout带来更低的测试错误率,sigmoid时会比dropout差点。DropConnect给模型压缩提供了一些思路,在训练时我们都倾向于选择更复杂的模型而需要非常大的计算量,DropConnect的做法表明这些复杂的模型实际上有大量的冗余,而去除这些冗余后并不会对模型产生任何伤害,反而会增强模型的泛化能力,因此在模型压缩中,对模型进行剪枝成了一个重要的研究方向。</p></li></ul><p>##Learning bothWeights and Connections for Efficient NeuralNetwork</p><ul><li><p>作者首先关注到神经网络预测时的能耗问题,下面给出了一个45nm的CMOS处理器能耗表。</p><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/5.png" /></p></li></ul><p>内存读取的能量消耗比其他数学指令高出三个数量级,因此论文提出对神经网络进行剪枝以压缩模型大小,减少内存读取消耗并降低计算量。剪枝不仅降低了模型复杂度,也减少了过拟合。除了剪枝,文中也提到可以借鉴HashedNets的方法进行模型参数共享,进一步降低模型大小。</p><p>模型剪枝分成三步:</p><p>1、正常训练模型,得到每个连接的重要程度(重要程度可以用权值的绝对值表示)</p><p>2、删除重要程度低的连接,将稠密网络转换成稀疏网络</p><p>3、使用保留下来的连接重训模型</p><p>第2步和第3步迭代进行。</p><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/6.png" /></p><ul><li><p>正则化</p><p>关于正则化对剪枝结果的影响,论文给出的结论是:剪枝后重训前L1正则比L2效果好,但重训后L2比L1效果好。</p></li><li><p>Dropout Ratio调整</p><p>Dropout仍然被用来抑制过拟合,但是由于剪枝会减小模型大小,因此重训时Dropoutratio也应该更小。</p><p><spanclass="math display">\[D_{r}=D_{0}\sqrt{\frac{C_{ir}}{C_{io}}}\]</span></p><p><span class="math display">\[C_{i}=N_{i}N_{i-1}\]</span></p><p>其中<span class="math inline">\(D_{r}\)</span>为重训的ratio,<spanclass="math inline">\(D_{0}\)</span>为原始的ratio,<spanclass="math inline">\(N_{i}\)</span>为第<spanclass="math 
inline">\(i\)</span>层的神经元个数。</p></li><li><p>重训参数</p><p>由于神经网络的连续层往往保持耦合性,因此重训模型时最好保持连接的权重,而不是重新初始化。并且卷积层和全连接层的剪枝是交替进行的,对fc进行剪枝重训时需要保持conv不变,反之对conv进行剪枝重训时需要保持fc不变。</p></li><li><p>迭代剪枝</p><p>迭代剪枝的方式可以最大程度的压缩模型大小。在不损失效果的前提下,相比单次剪枝,多次迭代的方式可以将AlexNet的压缩率从5X提高到9X。</p></li><li><p>裁剪神经元</p><p>每次剪枝可以将那些没有输入连接或没有输出连接的神经元移除。无输出的神经元对最终模型结果没有任何影响,因此移除也不会对模型效果产生影响,而那些没有输入连接的神经元由于梯度下降和正则化最终也会变成无输出的神经元。</p></li><li><p>实验结果</p><p>文中将裁剪门限设置为一个质量参数乘以这一层权重的标准差,并在LeNet、AlexNet和VGG-16上进行了相关实验,卷积层也可以跟全连接层一样使用相同的剪枝策略,重训模型时会有一次调整学习率的过程,比如LeNet重训时学习率会衰减到原来的1/10,AlexNet会衰减至原来的1/100。</p><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/1-2.png" /></p><p>AlexNet各层的压缩情况:<imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/1-3.png" /></p><p>剪枝与其他模型压缩方法的对比:</p><p><imgsrc="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pruning/2-2.png" /></p></li><li><p>模型保存</p><p>稀疏矩阵在保存时需要同时保存indices,比如按照CSR格式保存时,我们除了保存所有的非零元素外,还需要保存每个元素对应的列号以及每行第一个非零元素在所有元素中的位置。为了压缩保存indices带来的开销,文中提到使用相对indices代替绝对indices,全连接层可以使用5bit来表示相对indices,而卷积层也可以只使用8bit。</p></li><li><p>总结</p><p>由于卷积层本身就是稀疏连接,相比fc对剪枝更敏感,因此剪枝方法对于全连接层的压缩率更高。剪枝只能压缩模型大小,但inference时并不会带来预测速度提升。intel在16年提出另一个剪枝与嫁接相结合的方法<ahref="https://arxiv.org/pdf/1608.04493.pdf">Dynamic Network Surgery forEfficientDNNs</a>,进一步提高了剪枝方法的压缩率和重训收敛速度,此外2017年孙剑等提出了针对卷积层的<ahref="https://arxiv.org/pdf/1707.06168.pdf">ChannelPruning方法</a>,可以结合此处的剪枝方法,应该可以达到更好的压缩效果。</p></li></ul><p>##Channel Pruning for Accelerating Very Deep Neural Networks</p>]]></content>
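The pruning loop in the post hinges on one operation: zeroing connections whose magnitude falls below a threshold set to a quality parameter times the standard deviation of the layer's weights. A minimal NumPy sketch of that step (the function name and default quality value are my own, hypothetical choices):

```python
import numpy as np

def prune_by_magnitude(weights, quality=1.0):
    """Zero connections with |w| below quality * std(weights).

    Returns the pruned weights and the boolean keep-mask.
    """
    threshold = quality * weights.std()
    mask = np.abs(weights) >= threshold
    # During retraining, surviving weights keep their trained values and
    # the mask is re-applied after every update so pruned connections stay zero.
    return weights * mask, mask
```

In the iterative scheme described above, this would be re-run after each retraining round, alternating between convolutional and fully connected layers.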
<summary type="html"><h2
id="regularization-of-neural-networks-using-dropconnect">Regularization
of Neural Networks using DropConnect</h2>
<ul>
<li>DropConnect mainly targets overfitting in fully connected layers and is a generalization of Dropout. As the parameter counts of neural networks keep growing, the risk of overfitting rises; earlier remedies include L1/L2 regularization and Dropout. Dropout randomly zeroes activation outputs, so fewer parameters take part in each training step; because of the random drops, the network trained at each step may differ, so we are effectively training a mixture model composed of many sub-models.</li>
</ul></summary>
<category term="model compression" scheme="https://hjchen2.github.io/categories/model-compression/"/>
<category term="pruning" scheme="https://hjchen2.github.io/tags/pruning/"/>
</entry>
<entry>
<title>Neural Machine Translation: A Paper Walkthrough</title>
<link href="https://hjchen2.github.io/2017/12/01/seq2seq%E4%B8%B2%E8%AE%B2/"/>
<id>https://hjchen2.github.io/2017/12/01/seq2seq%E4%B8%B2%E8%AE%B2/</id>
<published>2017-12-01T04:24:08.000Z</published>
<updated>2023-02-07T02:38:51.695Z</updated>
<content type="html"><![CDATA[<h2 id="seq2seq">seq2seq</h2><p>主要学习的是论文Neural machine translation by jointly learning toalign and translate (Dzmitry Bahdanau、Yoshua Bengio等,2016.05)和Neuralmachine translation (Minh-ThangLuong,2016.12)。</p><p>神经机器翻译的目的是将一门语言的文本序列翻译成另一门语言的文本序列,因此机器翻译的训练语料一般是源语言和目标语言组成的一对文本,也叫做平行语料(parallelcorpus)。我们通常将输入和输出都是序列的模型叫做seq2seq,seq2seq不仅应用在机器翻译领域,也用于当前热门的自动问答系统以及文本摘要的自动生成等领域。</p><span id="more"></span><h2 id="encoder-decoder">Encoder-Decoder</h2><p>2014年Dzmitry Bahdanau、Yoshua Bengio等人在论文Learning PhraseRepresentations using RNN Encoder–Decoder for Statistical MachineTranslation中首次提出将RNNEncoder-Decoder结构来计算双语短语对的条件概率,用于改进统计机器翻译的效果。Encoder-Decoder是由encoder和decoder两部分组成,encoder将输入序列编码成定长的语义向量,decoder将语义向量进行解码得到目标序列。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/nmt/12c7a5370bc9da07193c0bd43c5b27cd.png?raw=true" width=500 align=center></p><p>在NMT中Encoder-Decoder试图直接对并行语料的条件概率<spanclass="math inline">\(P(Y|X)\)</span>进行建模,encoder输入的是一组向量序列<spanclass="math inline">\(X=(x_{1},…,x_{T_{x}})\)</span>,<spanclass="math inline">\(x_i\)</span>为词<spanclass="math inline">\(i\)</span>的one-hot编码向量,并将序列<spanclass="math inline">\(X\)</span>编码成语义向量<spanclass="math inline">\(c\)</span>,decoder输入语义向量<spanclass="math inline">\(c\)</span>,并逐个生成序列<spanclass="math inline">\(Y=(y_{1},…,y_{T_{y}})\)</span>,其中<spanclass="math inline">\(y_{i}\)</span>的生成与之前已经生成的词序列<spanclass="math inline">\(y_{1},…,y_{i-1}\)</span>有关。</p><p><span class="math display">\[\log p(Y|X)=\sum_{t=1}^{T_{y}}\logp(y_{t}|y_{<t}, c)\]</span></p><p>对于不定长度序列的编码和解码,我们很自然会想到RNN,实际上RNNEncoder–Decoder就是正反两组RNN拼接在一起组成的编码解码网络。经典的RNNEncoder–Decoder示意图如下:</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/nmt/ab3551f2c0f12a3bc50283e49e09e52c.png?raw=true" width=400 align=center /></p><p>我们可以用下面公式描述编码过程: <spanclass="math display">\[h_{t}=f(x_{t},h_{t-1})\]</span> <spanclass="math display">\[c=q({h_{1},…,h_{T_{x}}})\]</span></p><p>函数<spanclass="math 
inline">\(f\)</span>一般用一个RNN结构来表示,可以是LSTM、GRU等,<spanclass="math inline">\(h_{t}\)</span>表示encoderRNN在第t时刻的cell隐状态,向量c的计算与encoderRNN所有时刻的cell隐状态相关,函数<spanclass="math inline">\(q\)</span>可以表示所有隐状态的加权和,但由于RNN的特殊性,我们这里只使用最后一个时刻的隐状态作为向量<spanclass="math inline">\(c\)</span>,即<spanclass="math inline">\(c=h_{T_{x}}\)</span>。</p><p>对于解码过程,生成<spanclass="math inline">\(y_{t}\)</span>时的条件概率可以改写成</p><p><spanclass="math display">\[p(y_{t}|y_{<t},c)=g(y_{t-1},s_{t},c)\]</span><span class="math display">\[s_{t}=f(s_{t-1},y_{t-1},c)\]</span></p><p>其中,<spanclass="math inline">\(g\)</span>是非线性函数,可以是单层的softmax,也可以是一个多层结构的神经网络,<spanclass="math inline">\(y_{t-1}\)</span>表示上一时刻的输出,<spanclass="math inline">\(f\)</span>同样是一个RNN结构,<spanclass="math inline">\(s_{t}\)</span>表示decoder RNN cell的隐状态。</p><h2 id="attention">Attention</h2><p>在Encoder-Decoder中每个目标词生成时使用的都是同一个向量<spanclass="math inline">\(c\)</span>,虽然理论上来讲向量<spanclass="math inline">\(c\)</span>可以表示输入序列的语义信息,比如一些关键词、句子结构和语法信息等,但也存在注意力分散的问题。在机器翻译中,一般翻译出来的词与源序列的词是有对齐关系的,也就是说目标词的生成与源序列中的部分关键词关系更大,而其他词对当前目标词的生成影响就很小。在Encoder-Decoder中不论生成哪个目标词,使用的语义向量都是<spanclass="math inline">\(c\)</span>,而语义向量<spanclass="math inline">\(c\)</span>是由句子<spanclass="math inline">\(X\)</span>的每个单词经过Encoder编码而成的,也就意味着句子<spanclass="math inline">\(X\)</span>中的关键词对生成任意目标词的影响力是相同的。</p><p><img src="https://coding.net/u/hjchen2/p/personal/git/raw/master/blog/pictures/v2-db380a8bf032afa9533d358389de99d6_hd.jpg?raw=true" width=500></p><p>第一篇论文在Encoder-Decoder的基础上引入注意力机制,来解决上述注意力分散的问题。在论文中提出,每个目标词生成时使用的语义向量是不同的,也就是说Encoder-Decoder将会学会在生成目标词时给每个源语词分配权重,这个权重表示该源语词对当前目标词的重要程度。增加了attention机制的Encoder-Decoder框架如下图:</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/nmt/e9ba93ee15054825cb2c66a7180ef022.png?raw=true" width=400 align=center></p><p>在基于attention的模型中,每个目标词生成时的条件概率可以写成: <spanclass="math display">\[p(y_{i}|y_{<t},X)=g(y_{i-1},s_{i},c_{i})\]</span><span class="math display">\[s_{i}=f(s_{i-1},y_{i-1},c_{i})\]</span></p><p>在RNN中每个时刻的隐状态<spanclass="math 
inline">\(h_{i}\)</span>可以表示第<spanclass="math inline">\(i\)</span>个源语词及其周围部分词的信息,因此与之前的Encoder-Decoder框架不同,语义向量<spanclass="math inline">\(c_{i}\)</span>不再是encoderRNN最后一个时刻的隐状态,而是与encoder RNN所有时刻的隐状态(<spanclass="math inline">\(h_{1},...,h_{T_{x}}\)</span>)相关的一个向量。</p><p><spanclass="math display">\[c_{i}=\sum_{j=1}^{T_{x}}\alpha_{ij}h_{j}\]</span><span class="math inline">\(\alpha_{ij}\)</span>可以认为是目标词<spanclass="math inline">\(i\)</span>与源语词<spanclass="math inline">\(j\)</span>的对齐权重,因此可以使用源语词<spanclass="math inline">\(i\)</span>的隐状态<spanclass="math inline">\(h_{i}\)</span>和目标词前一时刻的隐状态<spanclass="math inline">\(s_{i-1}\)</span>来计算。 <spanclass="math display">\[\alpha_{ij}=\frac{\exp(e_{ij})}{\sum_{k=1}^{T_{x}}\exp(e_{ik})}\]</span>其中 <span class="math display">\[e_{ij}=a(s_{i-1},h_{j})\]</span> <spanclass="math inline">\(a\)</span>是一个对齐模型,在Bahdanau的论文中将其定义成一个前馈神经网络,与Encoder-Decoder一起参与训练。计算公式如下:<span class="math display">\[a(s_{i-1},h_{j})=v_{a}^\mathsf{T}\cdottanh(W_{a}s_{i-1}+U_{a}h_{j}) \]</span> <spanclass="math inline">\(v_{a}\)</span>、<spanclass="math inline">\(W_{a}\)</span>和<spanclass="math inline">\(U_{a}\)</span>都是对齐模型的参数。在第二篇ThangLuong的论文中提出下面三种计算方式,本质上也是大同小异。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/nmt/667d0e7417d384138f961490ff0745c3.png?raw=true" width=400 align=center></p><p>下图是Bahdanau在论文中给出的一个模拟图,图中模拟的是在给定源语序列(<spanclass="math inline">\(X_{1},X_{2},...,X_{T}\)</span>)的情况下生成第<spanclass="math inline">\(t\)</span>个目标词<spanclass="math inline">\(y_{t}\)</span>的过程。</p><p><img src="https://github.com/hjchen2/personal/blob/master/blog/nmt/970f70807791925f3f8f54266e0a8435.png?raw=true" width=300 align=center></p><h2 id="encoder">Encoder</h2><p>在Bahdanau的论文中Encoder和Decoder使用的都是GRU(Gated 
RecurrentUnit),GRU与LSTM一样都是RNN众多变体中比较常见的一种,也可以使用其他变体RNN,比如在ThangLuong的论文中主要用的就是LSTM。</p><p>我们知道传统的RNN理论上可以记忆无限长的序列,但由于递归权重对每个时刻的输入都是一样的,这就导致一个二选一的问题:(1)模型发散,无法收敛(2)梯度消失,无法产生长时记忆。GRU和LSTM一样,都是通过引入门(gate)的机制来解决传统RNN梯度消失的问题,gate打开和关闭是由当前时刻的输入和前一时刻的隐层状态控制的,也就是说每个时刻gate的状态都是不同的,一些需要长时间记忆的信息会通过gate一直传递下去,从而学习到长距离依赖。</p><p>传统RNN的隐层计算公式:<spanclass="math inline">\(h_{t}=g(W^{hh}h_{t-1}+W^{hx}x_{t})\)</span>,<spanclass="math inline">\(W^{hh}\)</span>是递归权重,<spanclass="math inline">\(W^{hx}\)</span>是隐层的权重。实际上,LSTM和GRU都可以认为是对<spanclass="math inline">\(h_{t}\)</span>计算方式的改进。</p><p>下面是GRU结构的示意图,输入为<spanclass="math inline">\(h_{t-1}\)</span>和<spanclass="math inline">\(x_{t}\)</span>,输出为<spanclass="math inline">\(h_{t}\)</span>。在GRU中存在两个gate,一个是resetgate,一个是update gate,分别对应下图中的<spanclass="math inline">\(r_{t}\)</span>和<spanclass="math inline">\(z_{t}\)</span>,<spanclass="math inline">\(\widetildeh_{t}\)</span>表示候选隐层状态,候选隐层状态与上一时刻的隐层状态<spanclass="math inline">\(h_{t-1}\)</span>一起更新当前时刻的隐层状态<spanclass="math inline">\(h_{t}\)</span>。</p><p><img src="https://coding.net/u/hjchen2/p/personal/git/raw/master/blog/pictures/rnn-gru-unit.png?raw=true" width=400 align=center></p><p>GRU的计算过程:<br />1、首先计算重置门<span class="math inline">\(r_{t}\)</span>和更新门<spanclass="math inline">\(z_{t}\)</span>,其中<spanclass="math inline">\(\sigma\)</span>表示sigmoid函数 <spanclass="math display">\[r_{t}=\sigma(W^{r}x_{t}+U^{r}h_{t-1})\]</span><spanclass="math display">\[z_{t}=\sigma(W^{z}x_{t}+U^{z}h_{t-1})\]</span>2、计算候选隐层状态<span class="math inline">\(\widetildeh_{t}\)</span>,其中<spanclass="math inline">\(r_{t}\)</span>用来控制历史记忆的传递,如果<spanclass="math inline">\(r_{t}=0\)</span>,那么<spanclass="math inline">\(\widetilde h_{t}\)</span>只与当前输入<spanclass="math inline">\(x_{t}\)</span>有关,历史记忆被重置。 <spanclass="math display">\[\widetilde h_{t}=tanh(Wx_{t}+U[r_{t}\odoth_{t-1}])\]</span> 实际上仅仅增加一个resetgate就已经可以解决长时依赖的问题,因为如果有需要<spanclass="math inline">\(r_{t}\)</span>可以总等于1,那么历史记忆就会一直传递下去。但这会带来一个问题,<spanclass="math 
inline">\(h_{t-1}\)</span>会累加到当前时刻的隐层状态上产生新的记忆,不断累加的记忆会导致<spanclass="math inline">\(\widetildeh_{t}\)</span>达到饱和,最终导致模型无法收敛。为了解决这个问题,GRU可以选择对当前输入产生的新记忆进行遗忘,只传递之前的历史记忆,也就是说我们允许GRU舍弃一些对后续无关的输入信息,保证记忆都是有效信息。GRU是通过下面的更新操作来实现这个过程的,<span class="math display">\[h_{t}=z_{t}\odot h_{t-1}+(1-z_{t})\odot\widetilde h_{t}\]</span> <spanclass="math inline">\(z_{i}\)</span>反映了相对历史记忆当前输入信息的重要程度,<spanclass="math inline">\(z_{i}\)</span>越小表明当前输入信息越重要。</p><p>实际上在Bahdanau的论文中使用的是双向RNN(BiRNN),BiRNN在前向RNN的基础上增加了一个反向RNN,使得RNN可以同时看到历史和未来的信息,最终前向RNN的隐层状态和反向RNN的隐层状态拼接后输出。</p><p><span class="math display">\[h_{i}=\left [ \begin{align} &\vec{h_{i}} \\ & \stackrel{\leftarrow}{h_{i}} \end{align}\right]\]</span></p><h2 id="decoder">Decoder</h2><p>在Bahdanau的论文中decoder采用是一个前向的GRU,但与encoderGRU不同的是decoder GRU需要额外输入语义向量<spanclass="math inline">\(c_{i}\)</span>。decoder GRU隐层状态<spanclass="math inline">\(s_{i}\)</span>的计算如下: <spanclass="math display">\[s_{i}=(1-z_{i})\odot s_{i-1}+z_{i}\odot\widetilde s_{i}\]</span> 其中,<br /><span class="math display">\[\widetilde s_{i}=tanh(Wy_{i-1}+U[r_{i}\odots_{i-1}]+Cc_{i})\]</span> <spanclass="math display">\[r_{i}=\sigma(W_{r}y_{i-1}+U_{r}s_{i-1}+C_{r}c_{i})\]</span><spanclass="math display">\[z_{i}=\sigma(W_{z}y_{i-1}+U_{z}s_{i-1}+C_{z}c_{i})\]</span>encoder GRU的隐层状态会被传递到decoderGRU用于生成第一个目标词,所以decoderGRU的隐层状态的初始值不是0,而是将encoder中反向GRU第一个时刻的隐层状态直接复制给decoderGRU,即<spanclass="math inline">\(s_{0}=tanh(W_{s}\stackrel{\leftarrow}{h_{1}})\)</span>。</p><h2 id="beam-search">beam search</h2>]]></content>
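The additive (Bahdanau) attention derived above — scores \(e_{ij}=v_{a}^\mathsf{T}\tanh(W_{a}s_{i-1}+U_{a}h_{j})\), softmax-normalized into alignment weights \(\alpha_{ij}\), then a context vector \(c_{i}=\sum_{j}\alpha_{ij}h_{j}\) — can be sketched in NumPy. The shapes and parameter names are illustrative.

```python
import numpy as np

def bahdanau_attention(s_prev, H, W_a, U_a, v_a):
    """s_prev: (d_s,) previous decoder state; H: (T, d_h) encoder hidden states."""
    # e_j = v_a^T tanh(W_a s_{i-1} + U_a h_j), computed for all source positions j at once
    scores = np.tanh(s_prev @ W_a.T + H @ U_a.T) @ v_a   # (T,)
    alpha = np.exp(scores - scores.max())                # numerically stable softmax
    alpha /= alpha.sum()
    context = alpha @ H                                  # weighted sum of encoder states, (d_h,)
    return context, alpha
```

The decoder then feeds this per-step context vector into the GRU update instead of a single fixed encoding, which is exactly what relieves the "diluted attention" problem discussed above.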
<summary type="html"><h2 id="seq2seq">seq2seq</h2>
<p>This post studies the papers Neural machine translation by jointly learning to
align and translate (Dzmitry Bahdanau, Yoshua Bengio et al., 2016.05) and Neural
machine translation (Minh-Thang Luong, 2016.12).</p>
<p>Neural machine translation aims to translate a text sequence in one language into a text sequence in another, so the training corpus is usually pairs of source-language and target-language text, also called a parallel
corpus. Models whose input and output are both sequences are commonly called seq2seq; beyond machine translation, seq2seq is also used in today's popular question-answering systems and in automatic text summarization.</p></summary>
<category term="neural machine translation" scheme="https://hjchen2.github.io/categories/neural-machine-translation/"/>
<category term="seq2seq" scheme="https://hjchen2.github.io/tags/seq2seq/"/>
<category term="machine translation" scheme="https://hjchen2.github.io/tags/machine-translation/"/>
<category term="Encoder-Decoder" scheme="https://hjchen2.github.io/tags/Encoder-Decoder/"/>
<category term="Attention" scheme="https://hjchen2.github.io/tags/Attention/"/>
</entry>
<entry>
<title>A Look at Alibaba's KunPeng Framework</title>
<link href="https://hjchen2.github.io/2017/08/22/KunPeng%E8%AE%BA%E6%96%87%E9%98%85%E8%AF%BB/"/>
<id>https://hjchen2.github.io/2017/08/22/KunPeng%E8%AE%BA%E6%96%87%E9%98%85%E8%AF%BB/</id>
<published>2017-08-22T04:53:08.000Z</published>
<updated>2023-02-07T02:49:01.177Z</updated>
<content type="html"><![CDATA[<p>KunPeng is a large-scale machine learning framework recently published by Alibaba. It supports data/model parallelism, load balancing, model synchronization, sparse representation, and industrial-grade fault tolerance, provides easy-to-use interfaces, and delivers large performance gains on many machine learning algorithms. Original paper: KunPeng: Parameter Server based Distributed Learning Systems and Its Applications in Alibaba and Ant Financial.</p><span id="more"></span><h2 id="introduction">Introduction</h2><p>This section mainly compares some general-purpose distributed computing frameworks.</p><p>Hadoop: provides only coarse-grained operations such as Map, Reduce, and Join. Several limitations make Hadoop-based machine learning algorithms very inefficient, including spilling intermediate results to disk and exchanging data only during the shuffling phase.</p><p>Spark: uses RDDs to make up for some of Hadoop's drawbacks and ships MLlib, which integrates many machine learning algorithms and is very easy to use. However, MLlib supports only medium-scale feature dimensions, and both its computation and communication are relatively inefficient. Some companies patch Spark's shortcomings with third-party components, but so far no complete solution exists.</p><p>GraphLab and GraphX: graph-based parallel computing frameworks that allow fine-grained user control, but they are not well suited to general machine learning algorithms such as LR or deep learning, and they also suffer from low efficiency.</p><p>MPI: flexible, efficient interfaces and a high degree of coding freedom; for example, any pair of processes can communicate at any time. But developing a new algorithm on MPI is very costly; a complex asynchronous matrix factorization algorithm takes more than 2,000 lines of code. MPI provides none of the components a distributed ML platform normally needs, such as distributed data loading, memory management, and multi-threaded parallelism. More importantly, MPI has no native solution for single points of failure; by the authors' statistics, the failure rate of MPI jobs grows with the number of nodes.</p><p>Parameter server: this architecture consists of stateless workers and stateful servers. Workers perform most of the computation, while servers store and update the model parameters. Servers can periodically snapshot the parameters to a backup location, and once a node fails the parameter server automatically restores the parameters from the latest checkpoint. The architecture only supports communication between servers and workers; there is no point-to-point communication among servers or among workers, and the fine-grained interfaces make user programming complex. Existing parameter server frameworks therefore have two problems: first, the communication interface is limited and less flexible than MPI's; second, they are not as easy to program against as Spark.</p><p>Because of these shortcomings, the authors developed a production-grade distributed learning system, KunPeng. KunPeng combines the strengths of the parameter server and MPI: it provides a robust failover mechanism, efficient communication interfaces for sparse data, and MPI-like general-purpose interfaces, together with C++ and Python SDKs that offer a development environment close to that of a single machine. KunPeng is also deeply integrated with Alibaba's Apsara platform, providing a full ML toolchain including SQL- and MapReduce-based data preprocessing, prediction, evaluation, and so on.</p><h2 id="kunpeng整体架构">Overall Architecture of KunPeng</h2><h3 id="apsara-cloud-platform">Apsara Cloud Platform</h3><p>Apsara is a large-scale distributed operating system developed by Alibaba; it currently runs on over a hundred thousand servers across dozens of data centers. In the figure below, the sky-blue parts are Apsara modules and the white parts are cloud services running on top of Apsara. KunPeng belongs to the white parts: it runs on Apsara, which provides services such as job scheduling, monitoring, and the file system. <img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/b2b0cb8a6973ec2b4281d68c328e4a0f.png?raw=true" alt="b2b0cb8a6973ec2b4281d68c328e4a0f" /> The job scheduling and resource management modules outlined in red are collectively called Fuxi. Fuxi supports a number of features to guarantee scalability and fault tolerance, including an incremental resource management protocol, user-transparent failure recovery, automatic fault detection, and a multi-level blacklist mechanism.</p><h2 id="kunpeng-platform">KunPeng Platform</h2><p>KunPeng consists of two subsystems, ML-Bridge and PS-Core. ML-Bridge is KunPeng's high-level programming model: through scripted workflows users can conveniently implement data preprocessing, training, prediction, evaluation, and other algorithms. PS-Core is a parameter server framework built on distributed key-value storage. <img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/0313b564c3646a4c4fab16574f9c4b4e.png?raw=true" alt="0313b564c3646a4c4fab16574f9c4b4e" /> ML-Bridge has three components:</p><ul><li>Interpreter: translates user scripts into algorithms supported by the system.</li><li>Optimizer: analyzes, debugs, and optimizes job configurations based on historical runtime statistics and heuristics.</li><li>Executor: turns the job configuration into the configuration that Fuxi schedules, monitors the whole job lifecycle, and provides a monitoring UI for users.</li></ul><p>ML-Bridge simplifies user programming. For example, a pipeline consisting of data ingestion and preprocessing, training, evaluation, and A/B testing takes only the few commands shown in the figure below. The whole pipeline is transparent to users, who need not care about the concrete algorithm implementations or the job scheduling process.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/ede2df215585fc86358bc9868565d1ce.png?raw=true" alt="ede2df215585fc86358bc9868565d1ce" /><figcaption aria-hidden="true">ede2df215585fc86358bc9868565d1ce</figcaption></figure><p>PS-Core supports both data parallelism and model parallelism, along with synchronous model updates (BSP), ASP and SSP, sparse representation, and fault tolerance. On top of the traditional workers and servers, PS-Core adds a coordinator for iteration control. The coordinator declares the data computation and parameter update operations, builds the job graph of the whole ML workflow, schedules these jobs onto workers and servers, and takes part in the failover of servers and workers. At the end of an iteration the coordinator synchronizes the iteration state with Apsara's meta service, and it is monitored and managed by Fuxi, so there is no SPOF (single point of failure) problem.</p><h3 id="容错方案">Fault Tolerance</h3><p>KunPeng provides failover for both servers and workers. Servers asynchronously snapshot the parameters to the distributed file system and also keep two replicas of the parameters in memory, supporting hot failover to speed up recovery. In most cases (for example, upon receiving a recovery request from the coordinator), servers can recover immediately from the in-memory replicas. Even if the servers or the whole job are interrupted or killed, servers can still resume training from the most recently saved parameters. For stateless workers, failover is trivial: they only need to pull the corresponding parameters from the servers. For stateful workers, the same snapshot interface is provided, so failover is also simple for algorithms in which workers keep local state (such as LDA).</p><p>In summary, KunPeng's failover works as follows: when Fuxi detects a failed node, it schedules a new node and asynchronously notifies the coordinator of the failure; the coordinator then sends recovery requests to the servers and workers. Surviving servers recover directly from memory, newly scheduled servers recover from the checkpoint, and workers first pull the corresponding parameters from the servers, with stateful workers additionally restoring their state from the saved checkpoint.</p><h3 id="dag调度">DAG Scheduling</h3><p>Scheduling here means the coordinator scheduling the servers and workers. The coordinator builds a job DAG from the algorithm's workflow and dispatches the DAG to the servers and workers for execution. To improve machine utilization and job efficiency, nodes at the same depth of the DAG can run in parallel, such as the Calculate for Block 0 and Load Data for Block 1 nodes in the figure below. Through the DAG interface users can customize I/O, computation, and communication, making it convenient to implement various model update algorithms.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/e76cf7c13015b83ed7696b5fa7c8dac0.png?raw=true" alt="e76cf7c13015b83ed7696b5fa7c8dac0" /><figcaption aria-hidden="true">e76cf7c13015b83ed7696b5fa7c8dac0</figcaption></figure><p>The figure below shows the C++ implementation of the bounded-delay ASGD algorithm in PS-Core; users can override the Iterate function to implement custom algorithms. In the figure, mServerParam and mServerGrad are the model parameters and gradients on the servers, mWorkerParam and mWorkerGrad are the workers' local model parameters and gradients, and mSubDatasetPtr is the current worker's data shard. nSync is the maximum number of delayed iterations, and nPull and nPush are the frequencies of pulling the latest parameters from the servers and pushing gradients to them. By setting nSync, nPull, and nPush one can conveniently switch between BSP and SSP, and removing the SyncBarrier yields an ASP implementation.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/69ed0d3573fbebf558494bc4a9a14c74.png?raw=true" alt="69ed0d3573fbebf558494bc4a9a14c74" /><figcaption aria-hidden="true">69ed0d3573fbebf558494bc4a9a14c74</figcaption></figure><h3 id="负载均衡和通信接口">Load Balancing and Communication Interfaces</h3><p>Because machines in a cluster differ in underlying hardware and runtime state, a job's efficiency largely depends on the slowest machine. There are many load-balancing techniques for this, such as assigning less data and computation to heavily loaded machines; PS-Core additionally designed a backup instance mechanism. When a node is identified as a straggler, the coordinator marks it as a "dead" node, asks Fuxi to schedule a new node as its backup, and migrates the straggler's load to the backup node. This mechanism usually brings a 10%-20% efficiency improvement.</p><p>KunPeng deeply optimizes communication for data of different sparsity and data types, and provides point-to-point communication interfaces among workers, such as AllReduce, ReduceTo, and Bcast. These flexible interfaces allow KunPeng to support more functionality, such as model parallelism.</p><h2 id="ftrl">FTRL</h2><p><span class="math display">\[w_{t+1}=\mathop{\arg\min}_{w}\left(\sum_{s=1}^{t}g_{s}w+\frac{1}{2}\sum_{s=1}^{t}\delta_{s}{\Vert}w-w_{s}{\Vert}_{2}^{2}+\lambda_{1}{\Vert}w{\Vert}_{1}+\lambda_{2}{\Vert}w{\Vert}_{2}^{2}\right)\]</span> where <span class="math inline">\(g\)</span> is the gradient of the loss with respect to <span class="math inline">\(w\)</span>, <span class="math inline">\(\delta_{t}=\frac{1}{\eta_{t}}-\frac{1}{\eta_{t-1}}\)</span>, so <span class="math inline">\(\sum_{s=1}^{t}{\delta_{s}}=\frac{1}{\eta_{t}}\)</span>. <span class="math inline">\(\eta\)</span> is the learning rate, with the per-coordinate rate <span class="math inline">\(\eta_{t,i}=\frac{\alpha}{\beta+\sqrt{\sum_{s=1}^{t}{g_{s,i}^2}}}\)</span>; usually <span class="math inline">\(\alpha=1\)</span>, and <span class="math inline">\(\beta\)</span> is a hyperparameter depending on the dataset and features. <span class="math inline">\(\lambda_{1}\)</span> is the L1 coefficient and <span class="math inline">\(\lambda_{2}\)</span> is the L2 coefficient. The per-coordinate closed-form update is<br /><span class="math display">\[w_{t+1,i}=\begin{cases}0 & \text{if } {\vert}z_{i}{\vert}\leq\lambda_{1}\\ -\left(\frac{\beta+\sqrt{n_{i}}}{\alpha}+\lambda_{2}\right)^{-1}\left(z_{i}-\mathrm{sign}(z_{i})\lambda_{1}\right) & \text{otherwise}\end{cases}\]</span> The figure below shows the single-machine update procedure of the LR FTRL-Proximal algorithm.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/66cf72a181547ae24831af8500b47d72.png?raw=true" alt="66cf72a181547ae24831af8500b47d72" /><figcaption aria-hidden="true">66cf72a181547ae24831af8500b47d72</figcaption></figure><p>This algorithm is easy to implement on a single machine, but a distributed implementation must consider communication efficiency, server load, and convergence. Given BSP's inefficiency and ASP's risk of non-convergence, the authors adopt a bounded-delay SSP update and set a trust region to constrain the parameter range and keep the model from diverging. The algorithm proceeds as follows:</p><ul><li>Workers keep local copies of the model <span class="math inline">\(w\)</span> and of <span class="math inline">\(z\)</span> and <span class="math inline">\(n\)</span>; <span class="math inline">\(z\)</span> and <span class="math inline">\(n\)</span> are synchronized with the servers in a bounded-asynchronous fashion.</li><li>Workers load data, update the local model <span class="math inline">\(w\)</span> from <span class="math inline">\(z\)</span> and <span class="math inline">\(n\)</span>, compute gradients and update the local <span class="math inline">\(w\)</span>, <span class="math inline">\(z\)</span>, and <span class="math inline">\(n\)</span>, while accumulating the increments of <span class="math inline">\(z\)</span> and <span class="math inline">\(n\)</span> into <span class="math inline">\(\delta_{z}\)</span> and <span class="math inline">\(\delta_{n}\)</span>; when it is time to synchronize with the servers, they push the accumulated <span class="math inline">\(\delta_{z}\)</span> and <span class="math inline">\(\delta_{n}\)</span> to the servers.</li><li>Servers merge the <span class="math inline">\(\delta_{z}\)</span> and <span class="math inline">\(\delta_{n}\)</span> sent by all workers and finally update the global <span class="math inline">\(z\)</span> and <span class="math inline">\(n\)</span>.</li></ul><p>Workers send the increments of <span class="math inline">\(z\)</span> and <span class="math inline">\(n\)</span> to the servers instead of raw model gradients. This costs some extra communication but lowers the servers' computational load; it is a trade-off between communication efficiency and computational load. To avoid divergence, the servers update the model within a trust region. There are two trust-region strategies: one rolls back the entire model when any of its elements exceeds the trusted range; the other maps (projects) the model's values into the trusted range.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/0de2241d38a792bb79446944d65d8c66.png?raw=true" alt="0de2241d38a792bb79446944d65d8c66" /><figcaption aria-hidden="true">0de2241d38a792bb79446944d65d8c66</figcaption></figure><h2 id="mart">MART</h2><p>MART (Multiple Additive Regression Trees), also known as GBDT, is a widely used machine learning algorithm. KunPeng implements a general MART supporting hundreds of billions of samples and thousands of features, and implements LambdaMART on top of it.</p><ul><li>MART: to handle extremely large data volumes, KunPeng-MART uses data parallelism to reduce memory usage and adopts XGBoost's distributed weighted histogram algorithm to optimize split finding. Concretely, every worker holds the entire tree; when splitting a leaf node, (1) each worker computes a local weighted histogram on its data shard and, once done, pushes the histogram to the servers; (2) the servers merge the workers' histograms with a multi-way merge algorithm to obtain the global histogram and find the best split; (3) the workers pull the split from the servers, split the node, and route their data into the resulting leaves.</li></ul><p>Repeating this process produces a complete tree; building trees one after another in the gradient boosting fashion then yields the final MART. As feature dimensionality and tree depth grow, both the computation and the communication in split finding can become bottlenecks. To address this, the authors mention using KunPeng's communication patterns to reduce the cost of merging local histograms, but do not disclose the details.</p><ul><li>LambdaMART: tree building is the same as in MART above; what differs is how LambdaMART computes the first- and second-order derivatives. Because LambdaMART pairs up training samples within the same query group, a problem arises when the training data is not stored contiguously by query group. The authors propose two solutions:<br />(1) First count the total number of samples per query id globally, then shard the query ids with a multiway number partitioning algorithm; each worker loads only the training samples of its own query ids.<br />(2) An approximate method: require that samples with the same query id be stored contiguously in the file system, and let each worker load its shard as usual. If samples with the same query id land on two different workers, the samples on the two workers are simply treated as belonging to different query ids.</li></ul><h2 id="其他算法">Other Algorithms</h2><ul><li>Large-scale sparse Logistic Regression (LR)<br />Several optimizers are implemented: L-BFGS, OWL-QN, and BCD, where BCD parallelizes over both data and model.<br /></li><li>Distributed Factorization Machines<br />Workers compute gradients asynchronously, using the AdaGrad optimizer.<br /></li><li>Caffe<br />KunPeng is integrated with Caffe as a generalized CPU-based large-scale deep learning platform, simplifying DL algorithm development.</li></ul><h2 id="实验结果">Experiments</h2><p>All experiments below were run on a production cluster of 5,000 servers, each with 12 Intel Xeon E5-2430 (2.2 GHz) CPUs and 96 GB of memory.</p><h3 id="kunpengspark和mpi的lr算法对比">LR Comparison: KunPeng, Spark, and MPI</h3><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/143e082b7f1a6b54e47e9c8b51026dbb.png?raw=true" alt="143e082b7f1a6b54e47e9c8b51026dbb" /><figcaption aria-hidden="true">143e082b7f1a6b54e47e9c8b51026dbb</figcaption></figure><p>LR on every platform uses L-BFGS with the memory history parameter set to 10, on the same cluster with the same CPU resources. Across 7 different datasets KunPeng shows a very clear advantage in both efficiency and memory usage.</p><p>On another dataset with 18 billion samples and 7 billion features, they measured KunPeng's speedup with different numbers of workers.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/00c84f368394ba04d59dbe530f69c387.png?raw=true" alt="00c84f368394ba04d59dbe530f69c387" /><figcaption aria-hidden="true">00c84f368394ba04d59dbe530f69c387</figcaption></figure><p>KunPeng can train this dataset with only 25 workers; it keeps a high speedup ratio as workers are added, and memory usage per worker drops almost linearly with the number of workers.</p><h3 id="kunpeng-mart和xgboost的对比">KunPeng-MART vs. XGBoost</h3><p>The figures below compare the peak memory usage and training time of KunPeng-MART and XGBoost on different tasks.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/1b0888cab293242eaccdc2b6e5bf25d9.png?raw=true" alt="1b0888cab293242eaccdc2b6e5bf25d9" /><figcaption aria-hidden="true">1b0888cab293242eaccdc2b6e5bf25d9</figcaption></figure><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/3b99dc82bc268d3da394a688c0234908.png?raw=true" alt="3b99dc82bc268d3da394a688c0234908" /><figcaption aria-hidden="true">3b99dc82bc268d3da394a688c0234908</figcaption></figure><h3 id="kunpeng-fmlibfm和difacto的对比">KunPeng-FM vs. LibFM and DiFacto</h3><p>Below is a comparison of training quality in the single-machine setting; no training-time numbers or multi-machine results are reported.</p><figure><img src="https://github.com/hjchen2/personal/blob/master/blog/KunPeng/da511a1bb0db987fb74ebb08fa5352c9.png?raw=true" alt="da511a1bb0db987fb74ebb08fa5352c9" /><figcaption aria-hidden="true">da511a1bb0db987fb74ebb08fa5352c9</figcaption></figure><h2 id="参考资料">References</h2><p>1. Ad Click Prediction: a View from the Trenches.</p>]]></content>
<summary type="html"><p>KunPeng is a large-scale machine learning framework recently published by Alibaba. It supports data/model parallelism, load balancing, model synchronization, sparse representation, and industrial-grade fault tolerance, provides easy-to-use interfaces, and delivers large performance gains on many machine learning algorithms.
Original paper: KunPeng: Parameter Server based Distributed Learning Systems
and Its Applications in Alibaba and Ant Financial.</p></summary>
<category term="ML framework" scheme="https://hjchen2.github.io/categories/ML-framework/"/>
<category term="large scale ML framework" scheme="https://hjchen2.github.io/tags/large-scale-ML-framework/"/>
<category term="KunPeng" scheme="https://hjchen2.github.io/tags/KunPeng/"/>
</entry>
<entry>
    <title>Calling Python from C++</title>
<link href="https://hjchen2.github.io/2017/07/03/C++%E8%B0%83%E7%94%A8Python%E6%8E%A5%E5%8F%A3/"/>
<id>https://hjchen2.github.io/2017/07/03/C++%E8%B0%83%E7%94%A8Python%E6%8E%A5%E5%8F%A3/</id>
<published>2017-07-03T04:31:08.000Z</published>
<updated>2023-01-03T14:04:13.435Z</updated>
<content type="html"><![CDATA[<p>I needed to build a reinforcement learning demo on top of a machine learning framework newly developed by our team, but the open-source game environments available today, such as Gym, only provide Python interfaces. Doing online training with Gym therefore requires calling the Python interfaces from C++ code, so I went through some examples to learn the Python C API. Of course the Python C API is not the only way; Boost's Python module is another option, which I will look into when I have time.</p><span id="more"></span><h2 id="hello-python">hello python</h2><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><stdio.h></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><iostream></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"python/Python.h"</span></span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">main</span><span class="params">()</span> </span>{</span><br><span class="line">    <span class="built_in">Py_Initialize</span>();</span><br><span class="line">    std::cout << <span class="string">"hello c++!"</span> << std::endl;</span><br><span class="line">    <span class="built_in">PyRun_SimpleString</span>(<span class="string">"print 'hello python!'"</span>);</span><br><span class="line">    <span class="built_in">Py_Finalize</span>();</span><br><span class="line">    <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>Compile:</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span
class="line">g++ test.cpp -o test -lpython</span><br></pre></td></tr></table></figure><p>执行:./test</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">hello c++!</span><br><span class="line">hello python!</span><br></pre></td></tr></table></figure><h2 id="调用python脚本中的函数">调用python脚本中的函数</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># test_add.py</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">add</span>(<span class="params">a, b</span>):</span><br><span class="line"> <span class="keyword">return</span> a+b</span><br></pre></td></tr></table></figure><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span 
class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><stdio.h></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><iostream></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"python/Python.h"</span></span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">main</span><span class="params">(<span class="type">int</span> argc, <span class="type">char</span>* argv[])</span> </span>{</span><br><span class="line"> <span class="keyword">if</span> (argc < <span class="number">3</span>) {</span><br><span class="line"> std::cerr << <span class="string">"Usage: ./exe integer1 integer2"</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"> std::cerr << <span class="string">"hello c++!"</span> << std::endl;</span><br><span class="line"></span><br><span class="line"> <span class="built_in">Py_Initialize</span>();</span><br><span class="line"> <span class="built_in">PyRun_SimpleString</span>(<span class="string">"import sys"</span>);</span><br><span class="line"> <span class="built_in">PyRun_SimpleString</span>(<span class="string">"sys.path.append('./')"</span>);</span><br><span class="line"></span><br><span class="line"> <span 
class="built_in">PyRun_SimpleString</span>(<span class="string">"print 'hello python!'"</span>);</span><br><span class="line"> PyObject* moduleName = <span class="built_in">PyString_FromString</span>(<span class="string">"test_add"</span>);</span><br><span class="line"> PyObject* pModule = <span class="built_in">PyImport_Import</span>(moduleName);</span><br><span class="line"> <span class="keyword">if</span> (!pModule) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] Python get module failed."</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"> PyObject* pv = <span class="built_in">PyObject_GetAttrString</span>(pModule, <span class="string">"add"</span>);</span><br><span class="line"> <span class="keyword">if</span> (!pv || !<span class="built_in">PyCallable_Check</span>(pv)) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] Can't find function (add)"</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> PyObject* args = <span class="built_in">PyTuple_New</span>(<span class="number">2</span>);</span><br><span class="line"> PyObject* arg1 = <span class="built_in">PyInt_FromLong</span>(<span class="built_in">atoi</span>(argv[<span class="number">1</span>]));</span><br><span class="line"> PyObject* arg2 = <span class="built_in">PyInt_FromLong</span>(<span class="built_in">atoi</span>(argv[<span class="number">2</span>]));</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args, <span class="number">0</span>, arg1);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args, <span class="number">1</span>, arg2);</span><br><span class="line"></span><br><span class="line"> PyObject* pRet = <span 
class="built_in">PyObject_CallObject</span>(pv, args);</span><br><span class="line"> <span class="keyword">if</span> (!pRet) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] Call funftion (add) failed"</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="type">long</span> result = <span class="built_in">PyInt_AsLong</span>(pRet);</span><br><span class="line"> std::cout << <span class="string">"result: "</span> << result << std::endl;</span><br><span class="line"></span><br><span class="line"> <span class="built_in">Py_Finalize</span>();</span><br><span class="line"> <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>编译:</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">g++ test.cpp -o test -lpython</span><br></pre></td></tr></table></figure><p>执行:./test 3 4</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">hello c++!</span><br><span class="line">hello python!</span><br><span class="line">result: 7</span><br></pre></td></tr></table></figure><h2 id="q学习的一个例子">Q学习的一个例子</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span 
class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># tree.py</span></span><br><span 
class="line"><span class="string">"""</span></span><br><span class="line"><span class="string">author: Houjiang Chen</span></span><br><span class="line"><span class="string">"""</span></span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">q_learning</span>:</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, states, actions</span>):</span><br><span class="line"> self.states = states</span><br><span class="line"> self.actions = actions</span><br><span class="line"> self.eps = <span class="number">0.1</span></span><br><span class="line"> self.alpha = <span class="number">0.1</span></span><br><span class="line"> self.q_table = [[<span class="number">0</span> <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(actions)] <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(states)]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">get_action</span>(<span class="params">self, current_state</span>):</span><br><span class="line"> max_action = self.q_table[current_state].index(<span class="built_in">max</span>(self.q_table[current_state]))</span><br><span class="line"> <span class="keyword">if</span> random.uniform(<span class="number">0</span>, <span class="number">1</span>) > self.eps:</span><br><span class="line"> <span class="keyword">return</span> max_action</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> rest = [i <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(self.q_table[current_state])) <span class="keyword">if</span> i != 
max_action]</span><br><span class="line"> index = random.randint(<span class="number">0</span>, <span class="built_in">len</span>(rest) - <span class="number">1</span>)</span><br><span class="line"> <span class="keyword">return</span> rest[index]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">update</span>(<span class="params">self, current_state, action, next_state, reward, final</span>):</span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> final:</span><br><span class="line"> reward = reward + <span class="built_in">max</span>(self.q_table[next_state])</span><br><span class="line"> self.q_table[current_state][action] += self.alpha * (reward - self.q_table[current_state][action])</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">environment</span>:</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line"> self.level = <span class="number">2</span></span><br><span class="line"> self.actions = <span class="number">2</span></span><br><span class="line"> self.states = self.actions ** (self.level + <span class="number">1</span>) - <span class="number">1</span></span><br><span class="line"> self.final_states = self.actions ** self.level</span><br><span class="line"> self.reward = {<span class="number">0</span> : [<span class="number">10</span>, -<span class="number">10</span>], <span class="number">1</span> : [<span class="number">50</span>, <span class="number">100</span>], <span class="number">2</span> : [<span class="number">100</span>, <span class="number">150</span>]}</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">next</span>(<span class="params">self, 
current_state, action</span>):</span><br><span class="line"> <span class="string">"""action: 0 or 1</span></span><br><span class="line"><span class="string"> return: next_state reward, is_final</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="built_in">next</span> = <span class="number">2</span> * current_state + (action + <span class="number">1</span>)</span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">next</span> >= self.states - self.final_states:</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">None</span>, self.reward[current_state][action], <span class="literal">True</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">next</span>, self.reward[current_state][action], <span class="literal">False</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="keyword">return</span> random.randint(<span class="number">0</span>, self.states - self.final_states - <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>():</span><br><span class="line"> env = environment()</span><br><span class="line"> agent = q_learning(env.states, env.actions)</span><br><span class="line"></span><br><span class="line"> episode = <span class="number">0</span></span><br><span class="line"> <span class="keyword">while</span> episode < <span class="number">10000</span>:</span><br><span class="line"> episode += <span class="number">1</span></span><br><span class="line"> <span class="built_in">print</span> <span class="string">"episode: %d"</span> % 
episode</span><br><span class="line"> current_state = env.reset()</span><br><span class="line"> <span class="keyword">while</span> <span class="literal">True</span>:</span><br><span class="line"> action = agent.get_action(current_state)</span><br><span class="line"> next_state, reward, final = env.<span class="built_in">next</span>(current_state, action)</span><br><span class="line"> agent.update(current_state, action, next_state, reward, final)</span><br><span class="line"> <span class="keyword">if</span> final:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> current_state = next_state</span><br><span class="line"></span><br><span class="line"> <span class="built_in">print</span> agent.q_table</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">'__main__'</span>:</span><br><span class="line"> main()</span><br></pre></td></tr></table></figure><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span 
class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><stdio.h></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><iostream></span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">"python2.7/Python.h"</span></span></span><br><span class="line"></span><br><span class="line"><span class="function">PyObject* <span class="title">New_PyInstance</span><span class="params">(PyObject* cls, PyObject* args)</span> </span>{</span><br><span class="line"> PyObject* pInstance = <span class="built_in">PyInstance_New</span>(cls, args, <span class="literal">NULL</span>);</span><br><span class="line"> <span class="keyword">if</span> (!pInstance) {</span><br><span class="line"> std::cerr << <span class="string">"new instance failed"</span> << std::endl;</span><br><span class="line"> <span class="built_in">exit</span>(<span class="number">1</span>);</span><br><span class="line"> 
}</span><br><span class="line"> <span class="keyword">return</span> pInstance;</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">main</span><span class="params">(<span class="type">int</span> argc, <span class="type">char</span>* argv[])</span> </span>{</span><br><span class="line"> <span class="built_in">Py_Initialize</span>();</span><br><span class="line"> <span class="built_in">PyRun_SimpleString</span>(<span class="string">"import sys"</span>);</span><br><span class="line"> <span class="built_in">PyRun_SimpleString</span>(<span class="string">"sys.path.append('./')"</span>);</span><br><span class="line"></span><br><span class="line"> PyObject* moduleName = <span class="built_in">PyString_FromString</span>(<span class="string">"tree"</span>);</span><br><span class="line"> PyObject* pModule = <span class="built_in">PyImport_Import</span>(moduleName);</span><br><span class="line"> <span class="keyword">if</span> (!pModule) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] Python get module failed."</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"> PyObject* pEnv = <span class="built_in">PyObject_GetAttrString</span>(pModule, <span class="string">"environment"</span>);</span><br><span class="line"> <span class="keyword">if</span> (!pEnv) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] Can't find class (environment)"</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> PyObject* pEnvObject = <span class="built_in">New_PyInstance</span>(pEnv, <span class="literal">NULL</span>);</span><br><span class="line"> PyObject* 
pEnvLevel = <span class="built_in">PyObject_GetAttrString</span>(pEnvObject, <span class="string">"level"</span>);</span><br><span class="line"> <span class="keyword">if</span> (!pEnvLevel) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] Env has no attr level"</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"> PyObject* pEnvActions = <span class="built_in">PyObject_GetAttrString</span>(pEnvObject, <span class="string">"actions"</span>);</span><br><span class="line"> PyObject* pEnvStates = <span class="built_in">PyObject_GetAttrString</span>(pEnvObject, <span class="string">"states"</span>);</span><br><span class="line"> PyObject* pEnvFinalState = <span class="built_in">PyObject_GetAttrString</span>(pEnvObject, <span class="string">"final_states"</span>);</span><br><span class="line"></span><br><span class="line"> <span class="type">int</span> level = <span class="built_in">PyInt_AsLong</span>(pEnvLevel);</span><br><span class="line"> <span class="type">int</span> actions = <span class="built_in">PyInt_AsLong</span>(pEnvActions);</span><br><span class="line"> <span class="type">int</span> states = <span class="built_in">PyInt_AsLong</span>(pEnvStates);</span><br><span class="line"> <span class="type">int</span> final_state = <span class="built_in">PyInt_AsLong</span>(pEnvFinalState);</span><br><span class="line"></span><br><span class="line"> std::cout << <span class="string">"env level: "</span> << level << std::endl;</span><br><span class="line"> std::cout << <span class="string">"env actions: "</span> << actions << std::endl;</span><br><span class="line"> std::cout << <span class="string">"env states: "</span> << states << std::endl;</span><br><span class="line"> std::cout << <span class="string">"env final_state: "</span> << final_state << std::endl;</span><br><span class="line"></span><br><span class="line"> 
PyObject* pLearn = <span class="built_in">PyObject_GetAttrString</span>(pModule, <span class="string">"q_learning"</span>);</span><br><span class="line"> PyObject* pLearnArgs = <span class="built_in">Py_BuildValue</span>(<span class="string">"ii"</span>, states, actions);</span><br><span class="line"> PyObject* pLearnObject = <span class="built_in">New_PyInstance</span>(pLearn, pLearnArgs);</span><br><span class="line"> PyObject* pLearnStates = <span class="built_in">PyObject_GetAttrString</span>(pLearnObject, <span class="string">"states"</span>);</span><br><span class="line"> PyObject* pLearnActions = <span class="built_in">PyObject_GetAttrString</span>(pLearnObject, <span class="string">"actions"</span>);</span><br><span class="line"> PyObject* pLearnEps = <span class="built_in">PyObject_GetAttrString</span>(pLearnObject, <span class="string">"eps"</span>);</span><br><span class="line"></span><br><span class="line"> <span class="type">int</span> learn_states = <span class="built_in">PyInt_AsLong</span>(pLearnStates);</span><br><span class="line"> <span class="type">int</span> learn_actions = <span class="built_in">PyInt_AsLong</span>(pLearnActions);</span><br><span class="line"> <span class="type">float</span> learn_eps = <span class="built_in">PyFloat_AsDouble</span>(pLearnEps);</span><br><span class="line"></span><br><span class="line"> std::cout << <span class="string">"learn_states: "</span> << learn_states << std::endl;</span><br><span class="line"> std::cout << <span class="string">"learn_actions: "</span> << learn_actions << std::endl;</span><br><span class="line"> std::cout << <span class="string">"learn_eps: "</span> << learn_eps << std::endl;</span><br><span class="line"></span><br><span class="line"> PyObject* pEnvResetFunc = <span class="built_in">PyObject_GetAttrString</span>(pEnvObject, <span class="string">"reset"</span>);</span><br><span class="line"> PyObject* pEnvNextFunc = <span class="built_in">PyObject_GetAttrString</span>(pEnvObject, <span 
class="string">"next"</span>);</span><br><span class="line"> PyObject* pLearnGetActionFunc = <span class="built_in">PyObject_GetAttrString</span>(pLearnObject, <span class="string">"get_action"</span>);</span><br><span class="line"> PyObject* pLearnUpdateFunc = <span class="built_in">PyObject_GetAttrString</span>(pLearnObject, <span class="string">"update"</span>);</span><br><span class="line"> <span class="keyword">if</span> (!pEnvNextFunc) {</span><br><span class="line"> std::cerr << <span class="string">"[ERROR] env has no function named next"</span> << std::endl;</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> std::cout << std::endl;</span><br><span class="line"> <span class="type">uint64_t</span> episode = <span class="number">0</span>;</span><br><span class="line"> <span class="keyword">for</span> (episode = <span class="number">0</span>; episode < <span class="number">10000</span>; ++episode) {</span><br><span class="line"> <span class="keyword">if</span> (episode % <span class="number">100</span> == <span class="number">0</span>)</span><br><span class="line"> std::cout << <span class="string">"episode: "</span> << episode << std::endl;</span><br><span class="line"> PyObject* current_state = <span class="built_in">PyEval_CallObject</span>(pEnvResetFunc, <span class="literal">NULL</span>);</span><br><span class="line"> <span class="keyword">while</span> (<span class="literal">true</span>) {</span><br><span class="line"> PyObject* args1 = <span class="built_in">PyTuple_New</span>(<span class="number">1</span>);</span><br><span class="line"> PyObject* args2 = <span class="built_in">PyTuple_New</span>(<span class="number">2</span>);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args1, <span class="number">0</span>, current_state);</span><br><span class="line"> PyObject* action = <span 
class="built_in">PyEval_CallObject</span>(pLearnGetActionFunc, args1);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args2, <span class="number">0</span>, current_state);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args2, <span class="number">1</span>, action);</span><br><span class="line"> PyObject* ret = <span class="built_in">PyEval_CallObject</span>(pEnvNextFunc, args2);</span><br><span class="line"> PyObject* next_state = <span class="built_in">PyTuple_GetItem</span>(ret, <span class="number">0</span>);</span><br><span class="line"> PyObject* <span class="keyword">final</span> = <span class="built_in">PyTuple_GetItem</span>(ret ,<span class="number">2</span>);</span><br><span class="line"> PyObject* args3 = <span class="built_in">PyTuple_New</span>(<span class="number">5</span>);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args3, <span class="number">0</span>, current_state);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args3, <span class="number">1</span>, action);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args3, <span class="number">2</span>, next_state);</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args3, <span class="number">3</span>, <span class="built_in">PyTuple_GetItem</span>(ret, <span class="number">1</span>));</span><br><span class="line"> <span class="built_in">PyTuple_SetItem</span>(args3, <span class="number">4</span>, <span class="keyword">final</span>);</span><br><span class="line"></span><br><span class="line"> <span class="built_in">PyEval_CallObject</span>(pLearnUpdateFunc, args3);</span><br><span class="line"> <span class="keyword">if</span> (<span class="built_in">PyObject_IsTrue</span>(<span class="keyword">final</span>)) {</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> 
current_state = next_state;</span><br><span class="line"> <span class="keyword">if</span> (args3)</span><br><span class="line"> <span class="built_in">Py_DECREF</span>(args3);</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> PyObject* pLearnQTable = <span class="built_in">PyObject_GetAttrString</span>(pLearnObject, <span class="string">"q_table"</span>);</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> i = <span class="number">0</span>; i < <span class="built_in">PyList_Size</span>(pLearnQTable); ++i) {</span><br><span class="line"> std::cout << <span class="string">"state "</span> << i << std::endl;</span><br><span class="line"> PyObject* term = <span class="built_in">PyList_GetItem</span>(pLearnQTable, i);</span><br><span class="line"> <span class="keyword">if</span> (<span class="built_in">PyList_Check</span>(term)) {</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> j = <span class="number">0</span>; j < <span class="built_in">PyList_Size</span>(term); ++j) {</span><br><span class="line"> std::cout << <span class="string">" direct: "</span> << j << <span class="string">", "</span> << <span class="string">"Qvalue: "</span></span><br><span class="line"> << <span class="built_in">PyFloat_AsDouble</span>(<span class="built_in">PyList_GetItem</span>(term, j)) << std::endl;</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="built_in">Py_Finalize</span>();</span><br><span class="line"> <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>Compile:</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">g++ test.cpp -o test -I../python2.7.12/include -L../python2.7.12/lib
-lpython2.7</span><br></pre></td></tr></table></figure><p>Run: ./test</p><figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span
class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span 
class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br></pre></td><td class="code"><pre><span class="line">env level: 2</span><br><span class="line">env actions: 2</span><br><span class="line">env states: 7</span><br><span class="line">env final_state: 4</span><br><span class="line">learn_states: 7</span><br><span class="line">learn_actions: 2</span><br><span class="line">learn_eps: 0.1</span><br><span class="line"></span><br><span class="line">episode: 0</span><br><span class="line">episode: 100</span><br><span class="line">episode: 200</span><br><span class="line">episode: 300</span><br><span class="line">episode: 400</span><br><span class="line">episode: 500</span><br><span class="line">episode: 600</span><br><span class="line">episode: 700</span><br><span class="line">episode: 800</span><br><span class="line">episode: 900</span><br><span class="line">episode: 1000</span><br><span class="line">episode: 1100</span><br><span class="line">episode: 1200</span><br><span class="line">episode: 1300</span><br><span class="line">episode: 1400</span><br><span class="line">episode: 1500</span><br><span class="line">episode: 1600</span><br><span class="line">episode: 1700</span><br><span class="line">episode: 1800</span><br><span class="line">episode: 1900</span><br><span class="line">episode: 2000</span><br><span class="line">episode: 2100</span><br><span class="line">episode: 2200</span><br><span class="line">episode: 2300</span><br><span class="line">episode: 2400</span><br><span class="line">episode: 2500</span><br><span class="line">episode: 2600</span><br><span class="line">episode: 2700</span><br><span class="line">episode: 2800</span><br><span class="line">episode: 2900</span><br><span 
class="line">episode: 3000</span><br><span class="line">episode: 3100</span><br><span class="line">episode: 3200</span><br><span class="line">episode: 3300</span><br><span class="line">episode: 3400</span><br><span class="line">episode: 3500</span><br><span class="line">episode: 3600</span><br><span class="line">episode: 3700</span><br><span class="line">episode: 3800</span><br><span class="line">episode: 3900</span><br><span class="line">episode: 4000</span><br><span class="line">episode: 4100</span><br><span class="line">episode: 4200</span><br><span class="line">episode: 4300</span><br><span class="line">episode: 4400</span><br><span class="line">episode: 4500</span><br><span class="line">episode: 4600</span><br><span class="line">episode: 4700</span><br><span class="line">episode: 4800</span><br><span class="line">episode: 4900</span><br><span class="line">episode: 5000</span><br><span class="line">episode: 5100</span><br><span class="line">episode: 5200</span><br><span class="line">episode: 5300</span><br><span class="line">episode: 5400</span><br><span class="line">episode: 5500</span><br><span class="line">episode: 5600</span><br><span class="line">episode: 5700</span><br><span class="line">episode: 5800</span><br><span class="line">episode: 5900</span><br><span class="line">episode: 6000</span><br><span class="line">episode: 6100</span><br><span class="line">episode: 6200</span><br><span class="line">episode: 6300</span><br><span class="line">episode: 6400</span><br><span class="line">episode: 6500</span><br><span class="line">episode: 6600</span><br><span class="line">episode: 6700</span><br><span class="line">episode: 6800</span><br><span class="line">episode: 6900</span><br><span class="line">episode: 7000</span><br><span class="line">episode: 7100</span><br><span class="line">episode: 7200</span><br><span class="line">episode: 7300</span><br><span class="line">episode: 7400</span><br><span class="line">episode: 7500</span><br><span class="line">episode: 
7600</span><br><span class="line">episode: 7700</span><br><span class="line">episode: 7800</span><br><span class="line">episode: 7900</span><br><span class="line">episode: 8000</span><br><span class="line">episode: 8100</span><br><span class="line">episode: 8200</span><br><span class="line">episode: 8300</span><br><span class="line">episode: 8400</span><br><span class="line">episode: 8500</span><br><span class="line">episode: 8600</span><br><span class="line">episode: 8700</span><br><span class="line">episode: 8800</span><br><span class="line">episode: 8900</span><br><span class="line">episode: 9000</span><br><span class="line">episode: 9100</span><br><span class="line">episode: 9200</span><br><span class="line">episode: 9300</span><br><span class="line">episode: 9400</span><br><span class="line">episode: 9500</span><br><span class="line">episode: 9600</span><br><span class="line">episode: 9700</span><br><span class="line">episode: 9800</span><br><span class="line">episode: 9900</span><br><span class="line">state 0</span><br><span class="line"> direct: 0, Qvalue: 110</span><br><span class="line"> direct: 1, Qvalue: 140</span><br><span class="line">state 1</span><br><span class="line"> direct: 0, Qvalue: 50</span><br><span class="line"> direct: 1, Qvalue: 100</span><br><span class="line">state 2</span><br><span class="line"> direct: 0, Qvalue: 100</span><br><span class="line"> direct: 1, Qvalue: 150</span><br><span class="line">state 3</span><br><span class="line"> direct: 0, Qvalue: 0</span><br><span class="line"> direct: 1, Qvalue: 0</span><br><span class="line">state 4</span><br><span class="line"> direct: 0, Qvalue: 0</span><br><span class="line"> direct: 1, Qvalue: 0</span><br><span class="line">state 5</span><br><span class="line"> direct: 0, Qvalue: 0</span><br><span class="line"> direct: 1, Qvalue: 0</span><br><span class="line">state 6</span><br><span class="line"> direct: 0, Qvalue: 0</span><br><span class="line"> direct: 1, Qvalue: 
0</span><br></pre></td></tr></table></figure><h2 id="参考资料">References</h2><p>Python/C API Reference Manual: https://docs.python.org/2/c-api/index.html</p>]]></content>
<summary type="html"><p>I needed to build a reinforcement learning demo on a newly developed in-house machine learning framework, but the open-source game environments currently available, such as Gym, only provide Python interfaces. Doing online training with Gym therefore requires calling the Python interface from C++ code, so I studied a few examples of how to use the Python C API. The Python C API is of course not the only option; Boost's Python module would also work, and I will look into it when I have time.</p></summary>
<category term="code" scheme="https://hjchen2.github.io/categories/code/"/>
<category term="c++" scheme="https://hjchen2.github.io/tags/c/"/>
<category term="python" scheme="https://hjchen2.github.io/tags/python/"/>
<category term="embedding" scheme="https://hjchen2.github.io/tags/embedding/"/>
</entry>
<entry>
<title>The effect of momentum in multi-node asynchronous updates</title>
<link href="https://hjchen2.github.io/2017/06/21/ASGD%E4%B8%ADmomentum%E7%9A%84%E5%BD%B1%E5%93%8D/"/>
<id>https://hjchen2.github.io/2017/06/21/ASGD%E4%B8%ADmomentum%E7%9A%84%E5%BD%B1%E5%93%8D/</id>
<published>2017-06-21T04:31:08.000Z</published>
<updated>2023-01-03T14:04:03.846Z</updated>
<content type="html"><![CDATA[<p>My main task these past few days has been porting Caffe to a newly developed in-house compute framework, and I ran into a problem while verifying correctness. Since the framework only supports asynchronous updates, training AlexNet with fully asynchronous SGD diverges very easily. I also looked into DC-ASGD, a recently published asynchronous update algorithm; in my experiments it had only a mildly positive effect on convergence and still could not stop training from diverging. On another DNN network I found that, with multiple machines, momentum has a large impact on the convergence result: momentum causes large fluctuations during convergence.</p><span id="more"></span><p>After searching around online, this seems to be the only reference of some value: http://stanford.edu/~imit/tuneyourmomentum/theory/</p><p>It looks like I will need to run some momentum and learning-rate tuning experiments soon...</p>]]></content>
<summary type="html"><p>My main task these past few days has been porting Caffe to a newly developed in-house compute framework, and I ran into a problem while verifying correctness. Since the framework only supports asynchronous updates, training AlexNet with fully asynchronous SGD diverges very easily. I also looked into DC-ASGD, a recently published asynchronous update algorithm; in my experiments it had only a mildly positive effect on convergence and still could not stop training from diverging. On another DNN network I found that, with multiple machines, momentum has a large impact on the convergence result: momentum causes large fluctuations during convergence.</p></summary>
<category term="deep learning" scheme="https://hjchen2.github.io/categories/deep-learning/"/>
<category term="caffe" scheme="https://hjchen2.github.io/tags/caffe/"/>
<category term="deep learning" scheme="https://hjchen2.github.io/tags/deep-learning/"/>
<category term="momentum" scheme="https://hjchen2.github.io/tags/momentum/"/>
</entry>
<entry>
<title>Reinforcement Learning (2)</title>
<link href="https://hjchen2.github.io/2017/04/25/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0%EF%BC%88%E4%BA%8C%EF%BC%89/"/>
<id>https://hjchen2.github.io/2017/04/25/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0%EF%BC%88%E4%BA%8C%EF%BC%89/</id>
<published>2017-04-25T04:31:08.000Z</published>
<updated>2023-05-19T03:51:57.657Z</updated>
<content type="html"><![CDATA[<h2 id="dqn">DQN</h2><p>As discussed in the previous post, TD algorithms combine the strengths of dynamic programming and Monte Carlo methods: they do not rely on a concrete model of the environment, and because updates take a moving-average form, a single step is enough to perform an update, with no need to generate a whole episode, so they remain applicable in non-episodic settings. TD algorithms further divide into the on-policy Sarsa algorithm and the off-policy Q-learning algorithm. Q-learning updates directly with the maximum action-value function of the next state, which speeds up convergence, so Q-learning is the more common choice in practice.</p><span id="more"></span><h3 id="q-learning例子">Q-learning example</h3><p>We illustrate the Q-learning procedure with an example. The figure below shows a path-planning problem represented as a binary tree. Each node is a state of the environment, leaf nodes are terminal states, and at every non-leaf node the agent can choose to move up or down, transitioning to the next node and receiving the corresponding reward.</p><div data-align="center"><img src="https://raw.githubusercontent.com/hjchen2/personal/master/blog/pictures/9930b76dc4a4c37e188ea6363fe6603b.png?raw=true" width=600></div><p><br /></p><p>First initialize the action-value function of every state-action pair: <span class="math inline">\(Q(S_{i},a)=0, \forall i\in[0,6], a\in[\text{up},\text{down}]\)</span>, and set <span class="math inline">\(\epsilon = 0.1, \alpha = 0.1\)</span>.</p><ul><li><p>Randomly choose an initial state <span class="math inline">\(S\)</span>, say <span class="math inline">\(S_0\)</span><br />Following the <span class="math inline">\(\epsilon\)</span>-greedy policy, pick an action, say up, and move to state <span class="math inline">\(S_1\)</span>; then update <span class="math inline">\(Q(S_0,\text{up})=Q(S_0,\text{up})+\alpha\cdot(R_{1}+\max_aQ(S_1,a)-Q(S_0,\text{up}))=0+0.1\cdot(10+0-0)=1\)</span>. Continuing with the <span class="math inline">\(\epsilon\)</span>-greedy policy, pick the next action, say down, and move to the terminal state <span class="math inline">\(S_4\)</span>, so <span class="math inline">\(Q(S_1,\text{down})=Q(S_1,\text{down})+\alpha\cdot(R_{2}+\max_aQ(S_4,a)-Q(S_1,\text{down}))=0+0.1\cdot(100+0-0)=10\)</span>.</p></li><li><p>Randomly choose an initial state <span class="math inline">\(S\)</span>, say <span class="math inline">\(S_2\)</span><br />Following the <span class="math inline">\(\epsilon\)</span>-greedy policy, pick an action, say up, and move to the terminal state <span class="math inline">\(S_5\)</span>; then update <span class="math inline">\(Q(S_2,\text{up})=0+0.1\cdot(100+0-0)=10\)</span></p></li><li><p>Randomly choose an initial state <span class="math inline">\(S\)</span>, say <span class="math inline">\(S_0\)</span><br />Following the <span class="math inline">\(\epsilon\)</span>-greedy policy, pick an action, say up, and move to state <span class="math inline">\(S_1\)</span>; then update <span class="math
inline">\(Q(S_0,\text{up})=1+0.1\cdot(10+10-1)=2.9\)</span>; pick the next action, say up, then <span class="math inline">\(Q(S_1,\text{up})=0+0.1\cdot(50+0-0)=5\)</span></p></li><li><p>Randomly choose an initial state <span class="math inline">\(S\)</span>, say <span class="math inline">\(S_0\)</span><br />Following the <span class="math inline">\(\epsilon\)</span>-greedy policy, pick an action, say up, and move to state <span class="math inline">\(S_1\)</span>; then update <span class="math inline">\(Q(S_0,\text{up})=2.9+0.1\cdot(10+10-2.9)=4.61\)</span>; pick the next action, say down, then <span class="math inline">\(Q(S_1,\text{down})=10+0.1\cdot(100+0-10)=19\)</span></p></li><li><p>…</p></li></ul><p>Below is a Python implementation of this example:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span
class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br></pre></td><td class="code"><pre><span class="line"><span class="string">"""</span></span><br><span class="line"><span class="string">author: Houjiang Chen</span></span><br><span class="line"><span class="string">"""</span></span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">q_learning</span>(<span class="title class_ inherited__">object</span>):</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, states, actions</span>):</span><br><span class="line"> self.states = states</span><br><span class="line"> self.actions = actions</span><br><span class="line"> self.eps = <span class="number">0.1</span></span><br><span class="line"> self.alpha = <span class="number">0.1</span></span><br><span class="line"> self.q_table = [[<span class="number">0</span> <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(actions)] <span class="keyword">for</span> i <span class="keyword">in</span> <span 
class="built_in">range</span>(states)]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">get_action</span>(<span class="params">self, current_state</span>):</span><br><span class="line"> max_action = self.q_table[current_state].index(<span class="built_in">max</span>(self.q_table[current_state]))</span><br><span class="line"> <span class="keyword">if</span> random.uniform(<span class="number">0</span>, <span class="number">1</span>) > self.eps:</span><br><span class="line"> <span class="keyword">return</span> max_action</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> rest = [i <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(self.q_table[current_state])) <span class="keyword">if</span> i != max_action]</span><br><span class="line"> index = random.randint(<span class="number">0</span>, <span class="built_in">len</span>(rest) - <span class="number">1</span>)</span><br><span class="line"> <span class="keyword">return</span> rest[index]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">update</span>(<span class="params">self, current_state, action, next_state, reward, final</span>):</span><br><span class="line"> <span class="keyword">if</span> final != <span class="number">1</span>:</span><br><span class="line"> reward = reward + <span class="built_in">max</span>(self.q_table[next_state])</span><br><span class="line"> self.q_table[current_state][action] += self.alpha * (reward - self.q_table[current_state][action])</span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="keyword">class</span> <span class="title class_">environment</span>(<span class="title class_ inherited__">object</span>):</span><br><span class="line"> <span 
class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line"> self.level = <span class="number">2</span></span><br><span class="line"> self.actions = <span class="number">2</span></span><br><span class="line"> self.states = self.actions ** (self.level + <span class="number">1</span>) - <span class="number">1</span></span><br><span class="line"> self.final_states = self.actions ** self.level</span><br><span class="line"> self.reward = {<span class="number">0</span> : [<span class="number">10</span>, -<span class="number">10</span>], <span class="number">1</span> : [<span class="number">50</span>, <span class="number">100</span>], <span class="number">2</span> : [<span class="number">100</span>, <span class="number">150</span>]}</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">next</span>(<span class="params">self, current_state, action</span>):</span><br><span class="line"> <span class="string">"""action: 0 or 1</span></span><br><span class="line"><span class="string"> return: next_state, reward, is_final</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="built_in">next</span> = <span class="number">2</span> * current_state + (action + <span class="number">1</span>)</span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">next</span> >= self.states - self.final_states:</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">None</span>, self.reward[current_state][action], <span class="number">1</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">next</span>, self.reward[current_state][action], <span class="number">0</span></span><br><span class="line"></span><br><span class="line"> <span 
class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="keyword">return</span> random.randint(<span class="number">0</span>, self.states - self.final_states - <span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"></span><br><span class="line">env = environment()</span><br><span class="line">agent = q_learning(env.states, env.actions)</span><br><span class="line"></span><br><span class="line">episode = <span class="number">0</span></span><br><span class="line"><span class="keyword">while</span> episode < <span class="number">100000</span>:</span><br><span class="line"> episode += <span class="number">1</span></span><br><span class="line"> <span class="built_in">print</span> <span class="string">"episode: %d"</span> % episode</span><br><span class="line"> current_state = env.reset()</span><br><span class="line"> <span class="keyword">while</span> <span class="literal">True</span>:</span><br><span class="line"> action = agent.get_action(current_state)</span><br><span class="line"> next_state, reward, final = env.<span class="built_in">next</span>(current_state, action)</span><br><span class="line"> agent.update(current_state, action, next_state, reward, final)</span><br><span class="line"> <span class="keyword">if</span> final:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> current_state = next_state</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span> agent.q_table</span><br><span class="line"> </span><br></pre></td></tr></table></figure><p>最终收敛结果为:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">[[<span class="number">109.99999999999989</span>, <span 
class="number">139.99999999999977</span>], </span><br><span class="line">[<span class="number">49.99999999999997</span>, <span class="number">99.99999999999994</span>], </span><br><span class="line">[<span class="number">99.99999999999994</span>, <span class="number">149.9999999999999</span>], </span><br><span class="line">[<span class="number">0</span>, <span class="number">0</span>], [<span class="number">0</span>, <span class="number">0</span>], [<span class="number">0</span>, <span class="number">0</span>], [<span class="number">0</span>, <span class="number">0</span>]]</span><br></pre></td></tr></table></figure><h3 id="函数逼近">函数逼近</h3><p>上面的例子中非终止状态数只有3个,每个非终止状态对应的动作只有2个,因此状态动作对总共有6个,使用表格存储完全没有问题,但实际上我们需要解决的并不是一个如此简单的问题。比如在【PlayingAtari with Deep Reinforcement Learning】中DeepMind就使用Qlearning使得agent玩Atari 2600游戏的水平超越了人类水平。在Atari2600游戏中,每个游戏画面都是一个状态,如果每个画面都是像素为84*84的256灰度图像,那么将会产生<spanclass="math inline">\(256^{84\cdot84}\)</span>个状态,用表格进行存储将会变得非常不现实。为了解决状态数爆炸的问题,通常可以使用函数逼近的方法。下面有几种函数表示的方式:</p><div data-align="center"><img src="https://github.com/hjchen2/personal/blob/master/blog/pictures/30EFF3D4-0562-4544-BFF9-D43B3EC7AFF7.png?raw=true" width=600></div><p></br></p><p>并且逼近函数的形式可以采用:</p><ul><li>Linear combinations of features</li><li>Neural network</li><li>Decision tree</li><li>Nearest neighbour</li><li>Fourier / wavelet bases</li><li>...</li></ul><p>下面我们研究的DQN(Deep Q Network)就是采用Deep neuralnetwork进行动作值函数逼近的一种方法,结构如下。</p><div data-align="center"><img src="https://github.com/hjchen2/personal/blob/master/blog/pictures/8e238f9d9836b789276e0e58d4aa1e34.png?raw=true" width=400></div><p></br></p><p>为推导方便,假设中间的Network为一层的全连接,即<spanclass="math inline">\(\hat{V}(s,a)=x(S)^{T}w=\sum_{j=1}^{n}{x_{j}(S)w_{j}}\)</span>,代价函数选择最小均方误差:<spanclass="math inline">\(J(w)=\frac{1}{2}(V(s,a)-\hat{V}(s,a))^2\)</span>,采用随机梯度下降算法进行优化。</p><p><spanclass="math display">\[\begin{split}\frac{\partial{J(w)}}{\partial{w}}&=\left(V(s,a)-\hat{V}(s,a)\right)\frac{\partial{\hat{V}(s,a)}}{\partial{w}} 
\\ &=\left(V(s,a)-\hat{V}(s,a)\right)x(S)\end{split}\tag{1-1}\]</span></p><p><span class="math display">\[\begin{split}w^k&=w^{k-1}+\eta\Delta(w)\\&=w^{k-1}-\eta\frac{\partial{J(w)}}{\partial{w}}\\&=w^{k-1}-\eta\left(V(s,a)-\hat{V}(s,a;w^{k})\right)x(S)\end{split}\tag{1-2}\]</span></p><p>由于我们并没有动作值函数的真实值,因此与Q learning类似,<spanclass="math inline">\(V(s,a,)\)</span>可以使用下一个状态的动作值函数进行估计,即<spanclass="math inline">\(V(s,a)=V(s,a;w^{k-1})=r+\gamma\max_{a^{'}}V(s^{'},a^{'};w^{k-1})\)</span>。</p><p>整个训练过程仍然与Q learning一样,采用<spanclass="math inline">\(\epsilon-greedy\)</span>策略选择动作,并按照公式(1-2)更新权重<spanclass="math inline">\(w\)</span>,实际上也就更新了策略的动作值函数。使用值函数逼近的方法不需要枚举每个状态动作对,突破了状态数的限制,使得Qlearning在一些复杂任务上得到广泛应用,但仍然没有解决动作数爆炸或者连续动作的问题。</p><h3 id="dqn-1">DQN</h3><p>DQN最先出现于DeepMind发表的【Playing Atari with Deep ReinforcementLearning】论文中,由于需要直接输入图像画面,因此论文中使用CNN来表示Q函数,下面简单剖析一下该论文。</p><p>使用的是典型的CNN,其结构为:</p><div data-align="center"><img src="https://github.com/hjchen2/personal/blob/master/blog/pictures/93F5C516-8E53-4F89-B03E-3EDD95DF1C76.png?raw=true" width=600></div><p></br>与一般的CNN有所不同的是,没有pooling层,因为我们这里不是做图像分类,pooling层带来的旋转和数值不变性对分类是有作用的,但在这个任务中对物体的具体位置是非常敏感的,因此移除了pooling层。</p><p>Atari原始的游戏帧为210*160像素的RGB图像,由于该任务对画面色彩不敏感,为了减少计算开销,将游戏帧预处理成84*84的灰度图像。但为了获得动态特征,最终是将前3帧图像与当前帧stack到一起组成一个4*84*84的图像作为CNN的输入,输出为每个动作对应的Q值。</p><h3 id="经验回放">经验回放</h3><p>现在我们知道可以使用Qlearning去估计每个状态的未来回报的期望,并且可以使用CNN去逼近动作值函数,也就是可以使用DQN去解决一个复杂的MDP任务。但在实际应用时会出现更新波动较大,导致收敛非常慢的问题,DeepMind因此使用了一个经验回放(ExperienceReplay)机制,就是将每步的经验数据<spanclass="math inline">\(<s,a,r,s^{'}>\)</span>存放在回放内存中,更新时都从回放内存中随机采样一个batch的数据进行更新。</p><p>经验回放机制相比标准的DQN有两个好处:首先每一步的经验数据会被保存起来,更新时可以多次使用到经验数据,使得数据利用更高效;此外直接从连续的样本中学习是低效的,因为一个episode内样本具有很强的相关性,随机挑选样本打破了这种相关性,因此减小了更新时的变化,使得更新更加稳定(注:因为同一次实验过程的样本相关性很强,不同实验之间的相关性就显得相对比较小,如果使用连续的样本进行训练,在切换到下一次实验的样本时会导致模型更新不稳定)。</p><p>由于内存大小限制,回放内存不可能将所有的经验数据都保存起来,因此只会保留最新的N组经验数据,比较久远的数据就会被遗忘。</p><h3 
id="训练">训练</h3><p>DeepMind使用DQN对ATARI中七个游戏进行了实验,由于每个游戏的得分尺度不一致,因此他们将得分分为正回报、负回报和无回报,正回报得分为1,负回报得分为-1,无回报得分为0。</p><p>使用 RMSProp算法进行优化,batch size为32,采用<spanclass="math inline">\(\epsilon-greedy\)</span>行动策略,前一百万帧的<spanclass="math inline">\(\epsilon\)</span>从1线性减少到0.1,最后固定为0.1。总共训练了一千万帧,并且使用了一百万大小的回放内存。</p><p>训练过程伪代码:</p><div data-align="center"><img src="https://github.com/hjchen2/personal/blob/master/blog/pictures/1E5C7D95-519A-4B54-BF09-C27A163D12C8.png?raw=true" width=600></div><h2 id="gym使用">Gym使用</h2><h3 id="gym简介">Gym简介</h3><p>目前强化学习的研究主要由DeepMind和OpenAI两家在主导,去年底到今年初DeepMind和OpenAI相继开源了自家的3Dlearning environment平台DeepMind Lab和Universe。DeepMindLab目前给出的文档和例子都比较少,使用也稍显复杂,所以暂时可以不考虑使用。Universe包含了1000+的游戏环境,并且将程序打包在docker环境中运行,提供与Gym一致的接口。Universe的环境由一个client和一个remote组成,client是一个VNCenv,主要负责接收agent的动作,传递回报和管理本地episode的状态,remote是指在docker环境中运行的程序,remote可以运行在本地、远程服务器或在cloud上。client和remote通过VNC远程桌面系统进行交互,通过WebSocket传递回报、诊断和控制信息。</p><p>由于Universe环境提供Gym接口,而Gym是OpenAI去年4月份发布的一套开发和比较强化学习算法的toolkit。Gym本身是可以独立于Universe使用的,并且Universe和Gym中agent代码基本没有什么区别。我们下面就单独讲讲Gym接口和如何使用Gym训练自己的agent。</p><p>Gym目前提供python接口,并支持任何的计算框架,比如tensorflow、theano等。强化学习解决的是agent和环境交互的任务,agent根据当前环境状态做出某个动作,然后观察下一个状态和回报,环境根据agent的动作转移到下一个状态,并发送回报。Gym提供的实际上是环境这个角色,每个Gym环境都提供一致的接口。</p><h3 id="创建一个gym环境">创建一个Gym环境</h3><p>创建一个环境时只需要指定环境id,比如agent需要玩AtariBreakout-v0这个游戏,可以如下创建一个Breakout-v0的环境。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> gym</span><br><span class="line">env = gym.make(<span class="string">'Breakout-v0'</span>)</span><br></pre></td></tr></table></figure><h3 id="step">step</h3><p>输入agent的动作,返回4个值,分别为:</p><ul><li>observation:表示agent观察到的下一个状态,比如在一些游戏中,observation为RGB的图像</li><li>reward:表示执行输入的动作后得到的回报值</li><li>done:表示返回的observation是不是结束状态</li><li>info:调试信息,一般没什么用处</li></ul><figure class="highlight python"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">next_state, reward, terminal, _ = env.step(action)</span><br></pre></td></tr></table></figure><h3 id="reset">reset</h3><p>在开始一个新的episode时,Gym环境都要reset,获得一个初始状态。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">init_state = env.reset()</span><br></pre></td></tr></table></figure><h3 id="render">render</h3><p>render是Gym用来渲染环境状态的函数,当调用该函数时会出现一个动图框。一般agent执行一个动作,环境都要渲染一次,这样就可以实时看到agent的执行情况了。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">env.render()</span><br></pre></td></tr></table></figure><h3 id="spaces">Spaces</h3><p>Gym环境有两个space属性,一个是action_space,一个是observation_space,分别表示该Gym环境下合法的动作和状态。action_space是Gym中的一个Discrete对象,Discrete对象有一个成员n,表示合法的动作数,比如Discrete(2)表示有两个合法动作,编号从0开始,因此两个动作编号为0和1。observation_space是Gym中的一个Box对象,Box的shape表示observation的数据组织方式,比如Box(210,160,3)表示合法的observation是一个210*160*3的数组,而Box(4,)表示observation是一个大小为4的向量。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">observation_space = env.observation_space <span class="comment"># observation_space: Discrete(6)</span></span><br><span class="line">action_space = env.action_space <span class="comment"># action_space: Box(210, 160, 3)</span></span><br></pre></td></tr></table></figure><h3 id="breakout-v0例子">Breakout-v0例子</h3><p>采用了github上Flood Sung的DQN实现,感谢Flood Sung大神的无私贡献。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span 
class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span 
class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br><span class="line">148</span><br><span class="line">149</span><br><span class="line">150</span><br><span class="line">151</span><br><span class="line">152</span><br><span class="line">153</span><br><span class="line">154</span><br><span class="line">155</span><br><span class="line">156</span><br><span class="line">157</span><br><span class="line">158</span><br><span class="line">159</span><br><span class="line">160</span><br><span class="line">161</span><br><span class="line">162</span><br><span class="line">163</span><br><span class="line">164</span><br><span class="line">165</span><br><span class="line">166</span><br><span class="line">167</span><br><span class="line">168</span><br><span class="line">169</span><br><span class="line">170</span><br><span class="line">171</span><br><span class="line">172</span><br><span class="line">173</span><br><span class="line">174</span><br><span class="line">175</span><br><span class="line">176</span><br><span class="line">177</span><br><span class="line">178</span><br><span class="line">179</span><br><span class="line">180</span><br><span class="line">181</span><br><span class="line">182</span><br><span class="line">183</span><br><span class="line">184</span><br><span class="line">185</span><br><span class="line">186</span><br><span class="line">187</span><br><span class="line">188</span><br><span class="line">189</span><br><span class="line">190</span><br><span 
class="line">191</span><br><span class="line">192</span><br><span class="line">193</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># -----------------------------</span></span><br><span class="line"><span class="comment"># File: Deep Q-Learning Algorithm</span></span><br><span class="line"><span class="comment"># Author: Flood Sung</span></span><br><span class="line"><span class="comment"># Date: 2016.3.21</span></span><br><span class="line"><span class="comment"># -----------------------------</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> deque</span><br><span class="line"></span><br><span class="line"><span class="comment"># Hyper Parameters:</span></span><br><span class="line">FRAME_PER_ACTION = <span class="number">1</span></span><br><span class="line">GAMMA = <span class="number">0.99</span> <span class="comment"># decay rate of past observations</span></span><br><span class="line">OBSERVE = <span class="number">100.</span> <span class="comment"># timesteps to observe before training</span></span><br><span class="line">EXPLORE = <span class="number">200000.</span> <span class="comment"># frames over which to anneal epsilon</span></span><br><span class="line">FINAL_EPSILON = <span class="number">0</span><span class="comment">#0.001 # final value of epsilon</span></span><br><span class="line">INITIAL_EPSILON = <span class="number">0</span><span class="comment">#0.01 # starting value of epsilon</span></span><br><span class="line">REPLAY_MEMORY = <span class="number">50000</span> <span class="comment"># number of previous transitions to 
remember</span></span><br><span class="line">BATCH_SIZE = <span class="number">32</span> <span class="comment"># size of minibatch</span></span><br><span class="line">UPDATE_TIME = <span class="number">100</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">BrainDQN</span>:</span><br><span class="line"><span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self,actions</span>):</span><br><span class="line"><span class="comment"># init replay memory</span></span><br><span class="line">self.replayMemory = deque()</span><br><span class="line"><span class="comment"># init some parameters</span></span><br><span class="line">self.timeStep = <span class="number">0</span></span><br><span class="line">self.epsilon = INITIAL_EPSILON</span><br><span class="line">self.actions = actions</span><br><span class="line"><span class="comment"># init Q network</span></span><br><span class="line">self.stateInput,self.QValue,self.W_conv1,self.b_conv1,self.W_conv2,self.b_conv2,self.W_conv3,self.b_conv3,self.W_fc1,self.b_fc1,self.W_fc2,self.b_fc2 = self.createQNetwork()</span><br><span class="line"></span><br><span class="line"><span class="comment"># init Target Q Network</span></span><br><span class="line">self.stateInputT,self.QValueT,self.W_conv1T,self.b_conv1T,self.W_conv2T,self.b_conv2T,self.W_conv3T,self.b_conv3T,self.W_fc1T,self.b_fc1T,self.W_fc2T,self.b_fc2T = self.createQNetwork()</span><br><span class="line"></span><br><span class="line">self.copyTargetQNetworkOperation = [self.W_conv1T.assign(self.W_conv1),self.b_conv1T.assign(self.b_conv1),self.W_conv2T.assign(self.W_conv2),self.b_conv2T.assign(self.b_conv2),self.W_conv3T.assign(self.W_conv3),self.b_conv3T.assign(self.b_conv3),self.W_fc1T.assign(self.W_fc1),self.b_fc1T.assign(self.b_fc1),self.W_fc2T.assign(self.W_fc2),self.b_fc2T.assign(self.b_fc2)]</span><br><span class="line"></span><br><span 
class="line">self.createTrainingMethod()</span><br><span class="line"></span><br><span class="line"><span class="comment"># saving and loading networks</span></span><br><span class="line">self.saver = tf.train.Saver()</span><br><span class="line">self.session = tf.InteractiveSession()</span><br><span class="line">self.session.run(tf.initialize_all_variables())</span><br><span class="line">checkpoint = tf.train.get_checkpoint_state(<span class="string">"saved_networks"</span>)</span><br><span class="line"><span class="keyword">if</span> checkpoint <span class="keyword">and</span> checkpoint.model_checkpoint_path:</span><br><span class="line">self.saver.restore(self.session, checkpoint.model_checkpoint_path)</span><br><span class="line"><span class="built_in">print</span> <span class="string">"Successfully loaded:"</span>, checkpoint.model_checkpoint_path</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"><span class="built_in">print</span> <span class="string">"Could not find old network weights"</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">createQNetwork</span>(<span class="params">self</span>):</span><br><span class="line"><span class="comment"># network weights</span></span><br><span class="line">W_conv1 = self.weight_variable([<span class="number">8</span>,<span class="number">8</span>,<span class="number">4</span>,<span class="number">32</span>])</span><br><span class="line">b_conv1 = self.bias_variable([<span class="number">32</span>])</span><br><span class="line"></span><br><span class="line">W_conv2 = self.weight_variable([<span class="number">4</span>,<span class="number">4</span>,<span class="number">32</span>,<span class="number">64</span>])</span><br><span class="line">b_conv2 = self.bias_variable([<span class="number">64</span>])</span><br><span class="line"></span><br><span class="line">W_conv3 = 
self.weight_variable([<span class="number">3</span>,<span class="number">3</span>,<span class="number">64</span>,<span class="number">64</span>])</span><br><span class="line">b_conv3 = self.bias_variable([<span class="number">64</span>])</span><br><span class="line"></span><br><span class="line">W_fc1 = self.weight_variable([<span class="number">1600</span>,<span class="number">512</span>])</span><br><span class="line">b_fc1 = self.bias_variable([<span class="number">512</span>])</span><br><span class="line"></span><br><span class="line">W_fc2 = self.weight_variable([<span class="number">512</span>,self.actions])</span><br><span class="line">b_fc2 = self.bias_variable([self.actions])</span><br><span class="line"></span><br><span class="line"><span class="comment"># input layer</span></span><br><span class="line"></span><br><span class="line">stateInput = tf.placeholder(<span class="string">"float"</span>,[<span class="literal">None</span>,<span class="number">80</span>,<span class="number">80</span>,<span class="number">4</span>])</span><br><span class="line"></span><br><span class="line"><span class="comment"># hidden layers</span></span><br><span class="line">h_conv1 = tf.nn.relu(self.conv2d(stateInput,W_conv1,<span class="number">4</span>) + b_conv1)</span><br><span class="line">h_pool1 = self.max_pool_2x2(h_conv1)</span><br><span class="line"></span><br><span class="line">h_conv2 = tf.nn.relu(self.conv2d(h_pool1,W_conv2,<span class="number">2</span>) + b_conv2)</span><br><span class="line"></span><br><span class="line">h_conv3 = tf.nn.relu(self.conv2d(h_conv2,W_conv3,<span class="number">1</span>) + b_conv3)</span><br><span class="line"></span><br><span class="line">h_conv3_flat = tf.reshape(h_conv3,[-<span class="number">1</span>,<span class="number">1600</span>])</span><br><span class="line">h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat,W_fc1) + b_fc1)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Q Value 
layer</span></span><br><span class="line">QValue = tf.matmul(h_fc1,W_fc2) + b_fc2</span><br><span class="line"></span><br><span class="line"><span class="keyword">return</span> stateInput,QValue,W_conv1,b_conv1,W_conv2,b_conv2,W_conv3,b_conv3,W_fc1,b_fc1,W_fc2,b_fc2</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">copyTargetQNetwork</span>(<span class="params">self</span>):</span><br><span class="line">self.session.run(self.copyTargetQNetworkOperation)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">createTrainingMethod</span>(<span class="params">self</span>):</span><br><span class="line">self.actionInput = tf.placeholder(<span class="string">"float"</span>,[<span class="literal">None</span>,self.actions])</span><br><span class="line">self.yInput = tf.placeholder(<span class="string">"float"</span>, [<span class="literal">None</span>])</span><br><span class="line">Q_Action = tf.reduce_sum(tf.mul(self.QValue, self.actionInput), reduction_indices = <span class="number">1</span>)</span><br><span class="line">self.cost = tf.reduce_mean(tf.square(self.yInput - Q_Action))</span><br><span class="line">self.trainStep = tf.train.AdamOptimizer(<span class="number">1e-6</span>).minimize(self.cost)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">trainQNetwork</span>(<span class="params">self</span>):</span><br><span class="line"><span class="comment"># Step 1: obtain random minibatch from replay memory</span></span><br><span class="line">minibatch = random.sample(self.replayMemory,BATCH_SIZE)</span><br><span class="line">state_batch = [data[<span class="number">0</span>] <span class="keyword">for</span> data <span class="keyword">in</span> minibatch]</span><br><span class="line">action_batch = [data[<span class="number">1</span>] 
<span class="keyword">for</span> data <span class="keyword">in</span> minibatch]</span><br><span class="line">reward_batch = [data[<span class="number">2</span>] <span class="keyword">for</span> data <span class="keyword">in</span> minibatch]</span><br><span class="line">nextState_batch = [data[<span class="number">3</span>] <span class="keyword">for</span> data <span class="keyword">in</span> minibatch]</span><br><span class="line"></span><br><span class="line"><span class="comment"># Step 2: calculate y</span></span><br><span class="line">y_batch = []</span><br><span class="line">QValue_batch = self.QValueT.<span class="built_in">eval</span>(feed_dict={self.stateInputT:nextState_batch})</span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">0</span>,BATCH_SIZE):</span><br><span class="line">terminal = minibatch[i][<span class="number">4</span>]</span><br><span class="line"><span class="keyword">if</span> terminal:</span><br><span class="line">y_batch.append(reward_batch[i])</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">y_batch.append(reward_batch[i] + GAMMA * np.<span class="built_in">max</span>(QValue_batch[i]))</span><br><span class="line"></span><br><span class="line">self.trainStep.run(feed_dict={</span><br><span class="line">self.yInput : y_batch,</span><br><span class="line">self.actionInput : action_batch,</span><br><span class="line">self.stateInput : state_batch</span><br><span class="line">})</span><br><span class="line"></span><br><span class="line"><span class="comment"># save network every 100000 iteration</span></span><br><span class="line"><span class="keyword">if</span> self.timeStep % <span class="number">10000</span> == <span class="number">0</span>:</span><br><span class="line">self.saver.save(self.session, <span class="string">'saved_networks/'</span> + <span class="string">'network'</span> + <span 
class="string">'-dqn'</span>, global_step = self.timeStep)</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> self.timeStep % UPDATE_TIME == <span class="number">0</span>:</span><br><span class="line">self.copyTargetQNetwork()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">setPerception</span>(<span class="params">self,nextObservation,action,reward,terminal</span>):</span><br><span class="line"><span class="comment">#newState = np.append(nextObservation,self.currentState[:,:,1:],axis = 2)</span></span><br><span class="line">newState = np.append(self.currentState[:,:,<span class="number">1</span>:],nextObservation,axis = <span class="number">2</span>)</span><br><span class="line">self.replayMemory.append((self.currentState,action,reward,newState,terminal))</span><br><span class="line"><span class="keyword">if</span> <span class="built_in">len</span>(self.replayMemory) > REPLAY_MEMORY:</span><br><span class="line">self.replayMemory.popleft()</span><br><span class="line"><span class="keyword">if</span> self.timeStep > OBSERVE:</span><br><span class="line"><span class="comment"># Train the network</span></span><br><span class="line">self.trainQNetwork()</span><br><span class="line"></span><br><span class="line"><span class="comment"># print info</span></span><br><span class="line">state = <span class="string">""</span></span><br><span class="line"><span class="keyword">if</span> self.timeStep <= OBSERVE:</span><br><span class="line">state = <span class="string">"observe"</span></span><br><span class="line"><span class="keyword">elif</span> self.timeStep > OBSERVE <span class="keyword">and</span> self.timeStep <= OBSERVE + EXPLORE:</span><br><span class="line">state = <span class="string">"explore"</span></span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">state = <span 
class="string">"train"</span></span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span> <span class="string">"TIMESTEP"</span>, self.timeStep, <span class="string">"/ STATE"</span>, state, \</span><br><span class="line"> <span class="string">"/ EPSILON"</span>, self.epsilon</span><br><span class="line"></span><br><span class="line">self.currentState = newState</span><br><span class="line">self.timeStep += <span class="number">1</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">getAction</span>(<span class="params">self</span>):</span><br><span class="line">QValue = self.QValue.<span class="built_in">eval</span>(feed_dict= {self.stateInput:[self.currentState]})[<span class="number">0</span>]</span><br><span class="line">action = np.zeros(self.actions)</span><br><span class="line">action_index = <span class="number">0</span></span><br><span class="line"><span class="keyword">if</span> self.timeStep % FRAME_PER_ACTION == <span class="number">0</span>:</span><br><span class="line"><span class="keyword">if</span> random.random() <= self.epsilon:</span><br><span class="line">action_index = random.randrange(self.actions)</span><br><span class="line">action[action_index] = <span class="number">1</span></span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">action_index = np.argmax(QValue)</span><br><span class="line">action[action_index] = <span class="number">1</span></span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">action[<span class="number">0</span>] = <span class="number">1</span> <span class="comment"># do nothing</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># change epsilon</span></span><br><span class="line"><span class="keyword">if</span> self.epsilon > FINAL_EPSILON <span class="keyword">and</span> self.timeStep >
OBSERVE:</span><br><span class="line">self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/EXPLORE</span><br><span class="line"></span><br><span class="line"><span class="keyword">return</span> action</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">setInitState</span>(<span class="params">self,observation</span>):</span><br><span class="line">self.currentState = np.stack((observation, observation, observation, observation), axis = <span class="number">2</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">weight_variable</span>(<span class="params">self,shape</span>):</span><br><span class="line">initial = tf.truncated_normal(shape, stddev = <span class="number">0.01</span>)</span><br><span class="line"><span class="keyword">return</span> tf.Variable(initial)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">bias_variable</span>(<span class="params">self,shape</span>):</span><br><span class="line">initial = tf.constant(<span class="number">0.01</span>, shape = shape)</span><br><span class="line"><span class="keyword">return</span> tf.Variable(initial)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">conv2d</span>(<span class="params">self,x, W, stride</span>):</span><br><span class="line"><span class="keyword">return</span> tf.nn.conv2d(x, W, strides = [<span class="number">1</span>, stride, stride, <span class="number">1</span>], padding = <span class="string">"SAME"</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">max_pool_2x2</span>(<span class="params">self,x</span>):</span><br><span class="line"><span class="keyword">return</span> tf.nn.max_pool(x, ksize = [<span class="number">1</span>, <span 
class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], strides = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding = <span class="string">"SAME"</span>)</span><br></pre></td></tr></table></figure><p>下面是使用上面的DQN让agent玩Gym的Breakout-v0游戏。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># -------------------------</span></span><br><span class="line"><span class="comment"># Project: Deep 
Q-Learning on Breakout-v0</span></span><br><span class="line"><span class="comment"># Author: Houjiang Chen</span></span><br><span class="line"><span class="comment"># Date: 2017.4.25</span></span><br><span class="line"><span class="comment"># -------------------------</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> cv2</span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">from</span> BrainDQN_Nature <span class="keyword">import</span> BrainDQN</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"></span><br><span class="line"><span class="comment"># preprocess raw image to 80*80 gray image</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">preprocess</span>(<span class="params">observation</span>):</span><br><span class="line"> observation = cv2.cvtColor(cv2.resize(observation, (<span class="number">80</span>, <span class="number">80</span>)), cv2.COLOR_BGR2GRAY)</span><br><span class="line"> <span class="comment">#ret, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)</span></span><br><span class="line"> <span class="keyword">return</span> np.reshape(observation, (<span class="number">80</span>, <span class="number">80</span>, <span class="number">1</span>))</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">play</span>():</span><br><span class="line"> env = gym.make(<span class="string">'Breakout-v0'</span>)</span><br><span class="line"> actions = env.action_space.n</span><br><span class="line"></span><br><span class="line"> <span class="comment"># init BrainDQN</span></span><br><span class="line"> brain = BrainDQN(actions)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">while</span> <span 
class="number">1</span>:</span><br><span class="line"> state = env.reset()</span><br><span class="line"> state = cv2.cvtColor(cv2.resize(state, (<span class="number">80</span>, <span class="number">80</span>)), cv2.COLOR_BGR2GRAY)</span><br><span class="line"> <span class="comment">#ret, state = cv2.threshold(state, 1, 255, cv2.THRESH_BINARY)</span></span><br><span class="line"> brain.setInitState(state)</span><br><span class="line"> <span class="keyword">while</span> <span class="number">1</span>:</span><br><span class="line"> action = brain.getAction()</span><br><span class="line"> state, reward, terminal, _ = env.step(np.argmax(action))</span><br><span class="line"> env.render()</span><br><span class="line"> state = preprocess(state)</span><br><span class="line"> brain.setPerception(state, action, reward, terminal)</span><br><span class="line"> <span class="keyword">if</span> terminal:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>():</span><br><span class="line"> play()</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">'__main__'</span>:</span><br><span class="line"> main()</span><br></pre></td></tr></table></figure><h2 id="参考资料">参考资料</h2><p>1、Reinforcement Learning: An Introduction, Richard S. Sutton and Andrew G. Barto, 2012<br />2、Playing Atari with Deep Reinforcement Learning, DeepMind Technologies, arXiv 2013.12<br />3、Human-level control through deep reinforcement learning, DeepMind Technologies, Nature 2015.02<br />4、DeepMind官网 https://deepmind.com/blog/deep-reinforcement-learning<br />5、https://www.nervanasys.com/demystifying-deep-reinforcement-learning<br />6、http://www.cnblogs.com/jinxulin/p/3511298.html<br />7、Introduction to Reinforcement Learning, David Silver</p>]]></content>
<summary type="html"><h2 id="dqn">DQN</h2>
<p>前面我们讲到TD算法结合了动态规划和蒙特卡洛算法的优点,不依赖具体的环境模型,并且更新时采用滑动平均的方式,因此单步就能更新,而不需要生成整个episode,在非episode情况下仍然适用。TD算法又分为on
policy的sarsa算法和off policy的Q learning算法,其中Q
learning算法直接使用下一状态的最大动作值函数进行更新,加快了算法收敛速度,因此Q
learning算法在实际应用中更加普遍。</p></summary>
<category term="reinforcement learning" scheme="https://hjchen2.github.io/categories/reinforcement-learning/"/>
<category term="reinforcement learning" scheme="https://hjchen2.github.io/tags/reinforcement-learning/"/>
<category term="machine learning" scheme="https://hjchen2.github.io/tags/machine-learning/"/>
</entry>
<entry>
<title>值函数的贝尔曼公式推导</title>
<link href="https://hjchen2.github.io/2017/04/10/%E5%80%BC%E5%87%BD%E6%95%B0%E7%9A%84%E8%B4%9D%E5%B0%94%E6%9B%BC%E5%85%AC%E5%BC%8F%E6%8E%A8%E5%AF%BC/"/>
<id>https://hjchen2.github.io/2017/04/10/%E5%80%BC%E5%87%BD%E6%95%B0%E7%9A%84%E8%B4%9D%E5%B0%94%E6%9B%BC%E5%85%AC%E5%BC%8F%E6%8E%A8%E5%AF%BC/</id>
<published>2017-04-10T04:31:08.000Z</published>
<updated>2023-01-03T14:06:55.430Z</updated>
<content type="html"><![CDATA[<p>下面的推导过程中第2步和第5步两次用到重期望公式: <span class="math inline">\(\bf{EX}=\bf{E\left(E\left[X\mid Y\right]\right)}\)</span>。</p><span id="more"></span><p><span class="math display">\[\begin{split}\upsilon_{\pi}(s)&={\bf{E_{\pi}}}\left[G_{t}\mid{S_{t}=s}\right] \\&={\bf{E_{\pi}}}\left({\bf{E_{\pi}}}\left[G_t\mid S_t=s,A_t\right]\right) \\&={\bf{E_{\pi}}}\left[\sum_a\pi(a|s)G_t\mid S_t=s,A_t=a\right] \\&=\sum_a\pi(a|s){\bf{E_{\pi}}}\left[G_t\mid S_t=s,A_t=a\right] \\&=\sum_a\pi(a|s){\bf{E_{\pi}}}\left({\bf{E_{\pi}}}\left[G_t\mid S_t=s,A_t=a,S_{t+1}\right]\right) \\&=\sum_a\pi(a|s){\bf{E_{\pi}}}\left[\sum_{s^{'}}p(s^{'}\mid s,a)G_t\mid S_t=s,A_t=a,S_{t+1}=s^{'}\right] \\&=\sum_a\pi(a|s)\sum_{s^{'}}p(s^{'}\mid s,a){\bf{E_{\pi}}}\left[G_t\mid S_t=s,A_t=a,S_{t+1}=s^{'}\right] \\&=\sum_{a}\pi(a\mid{s})\sum_{s^{'}}p(s^{'}\mid s,a){\bf E}_{\pi}\left[R_{t+1}+\gamma\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+2}\mid{S_{t}=s,A_{t}=a,S_{t+1}=s^{'}}\right]\\&=\sum_{a}\pi(a\mid{s})\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma{\bf E}_{\pi}\left[\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+2}\mid{S_{t+1}=s^{'}}\right]\right]\\&=\sum_{a}\pi(a\mid{s})\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{\pi}(s^{'})\right]\end{split}\]</span></p>]]></content>
<summary type="html"><p>下面的推导过程中第2步和第5步两次用到重期望公式: <span
class="math inline">\(\bf{EX}=\bf{E\left(E\left[X\mid
Y\right]\right)}\)</span>。</p></summary>
<category term="reinforcement learning" scheme="https://hjchen2.github.io/categories/reinforcement-learning/"/>
<category term="reinforcement learning" scheme="https://hjchen2.github.io/tags/reinforcement-learning/"/>
<category term="machine learning,贝尔曼公式推导" scheme="https://hjchen2.github.io/tags/machine-learning%EF%BC%8C%E8%B4%9D%E5%B0%94%E6%9B%BC%E5%85%AC%E5%BC%8F%E6%8E%A8%E5%AF%BC/"/>
</entry>
<entry>
<title>强化学习(一)</title>
<link href="https://hjchen2.github.io/2017/03/27/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0%EF%BC%88%E4%B8%80%EF%BC%89/"/>
<id>https://hjchen2.github.io/2017/03/27/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0%EF%BC%88%E4%B8%80%EF%BC%89/</id>
<published>2017-03-27T04:31:08.000Z</published>
<updated>2023-02-07T02:39:23.375Z</updated>
<content type="html"><![CDATA[<h2 id="前言">前言</h2><p>近几年,由于DeepMind成功地将强化学习(reinforcementlearning)运用在AlphaGo上,机器首次在复杂任务上取得了超过人类的表现,使得强化学习成为目前机器学习研究的前沿方向之一。强化学习由来已久,Sutton等在1979年就已经开始研究强化学习,1998年出版了强化学习介绍一书,并于2012年发布第二版,本文前几部分内容主要参考该书。</p><span id="more"></span><p>强化学习最早主要用于智能控制领域,比如机器人控制、电梯调度、电信通讯等,如今已经在自动驾驶、NLP、内容推荐<sup>[4]</sup>和语音交互领域都有相关的应用。2013年底DeepMind发表文章PlayingAtari with Deep ReinforcementLearning,首次成功地将深度学习运用到强化学习任务上,通过无监督学习实现从纯图像输入来玩Atari2600游戏的效果。而后DeepMind逐渐改进算法,使得DQN在Atari几乎一半的游戏中超过人类水平,以至2016年AlphaGo和无人车的出现,人们惊奇地发现人工智能即将颠覆我们的生活,甚至有人评论说传统的深度学习已经可以很好地感知理解了,强化学习可以利用这些感知生成策略,因而可以创造更高的机器智能。</p><p>下面是DeepMind使用DQN让机器学习玩Atari 2600游戏的视频。<iframe width="895" height="503" src="https://www.youtube.com/embed/TmPfTpjtdgg" frameborder="0" allowfullscreen></iframe></p><h2 id="什么是强化学习">什么是强化学习</h2><p>Reinforcement learning is learning what to do—how to map situationsto actions—so as to maximize a numerical rewardsignal<sup>[1]</sup>.</p><p>强化学习研究的是智能体agent与环境之间交互的任务,也就是让agent像人类一样通过试错,不断地学习在不同的环境下做出最优的动作,而不是有监督地直接告诉agent在什么环境下应该做出什么动作。在这里我们需要引入回报(reward)这个概念,回报是执行一个动作或一系列动作后得到的奖励,比如在游戏超级玛丽中,向上跳可以获得一个金币,也就是回报值为1,而不跳时回报就是0。回报又分为立即回报和长期回报,立即回报指的是执行当前动作后能立刻获得的奖励,但很多时候我们执行一个动作后并不能立即得到回报,而是在游戏结束时才能返回一个回报值,这就是长期回报。强化学习唯一的准则就是学习通过一序列的最优动作,获得最大的长期回报。比较有挑战性的是,任一状态下做出的动作不仅影响当前状态的立即回报,而且也会影响到下一个状态,因此也就会影响整个执行过程的回报。</p><p>因此,强化学习和监督学习的区别主要有以下两点<sup>[6]</sup>:</p><ol type="1"><li>强化学习是试错学习(Trail-and-error),由于没有直接的指导信息,智能体要以不断与环境进行交互,通过试错的方式来获得最佳策略。</li><li>延迟回报,强化学习的指导信息很少,而且往往是在事后(最后一个状态)才给出的,这就导致了一个问题,就是获得正回报或者负回报以后,如何将回报分配给前面的状态。</li></ol><h2 
id="问题描述与mdp">问题描述与MDP</h2><p>前面已经提到强化学习是尝试并发现回报最大动作的过程,下面就具体来描述一下这个过程。首先考虑一个问题,一个之前完全没有接触过国际象棋的小白怎样和一个专业棋手对弈。刚开始小白对棋面并没有任何概念,只能随机下,但假设双方每一轮下完后都会得到立即回报,比如吃子回报为1,被吃回报为-1,其他回报为0。可以想象一开始小白会输得很惨,但如果小白很聪明,随着不断地尝试小白不仅理解了下棋的规则,并且知道在什么棋面下做出什么动作可以吃更多的棋子。在这里我们将小白作为我们的智能体agent,棋面就是状态,下棋就是agent根据当前状态做出的动作,每个动作执行完后都会引起状态改变,如果状态的改变只与前一个状态和当前的动作有关,而与之前的状态和动作无关(即满足马尔可夫性),那么整个过程可以用马尔可夫决策过程(MarkovDecisionProcesses)来描述,而Sutton在书中直接将满足马尔可夫性的强化学习任务定义为马尔可夫决策过程,并将状态和动作都是有限空间的MDP定义为有限马尔可夫决策过程(finiteMDP)。</p><p>下面引入一些定义<sup>[1]</sup>:马尔可夫决策过程是一个agent与环境交互的过程,因此有一个离散的时间序列,<spanclass="math inline">\(t=0,1,2,3,...\)</span>,在每一个时刻<spanclass="math inline">\(t\)</span>,agent都会接收一个用来表示环境的状态<spanclass="math inline">\(S_{t}\in\bf{S}\)</span>,其中<spanclass="math inline">\(\bf{S}\)</span>表示所有可能状态的集合,并且在状态的基础上选择一个动作<spanclass="math inline">\(A_{t}\in{\bf{A}}(S_{t})\)</span>,其中<spanclass="math inline">\({\bf{A}}(S_{t})\)</span>表示在状态<spanclass="math inline">\(S_{t}\)</span>时所有可能采取的动作的集合,在<spanclass="math inline">\(t\)</span>时刻agent采取一个动作后都会收到一个回报值<spanclass="math inline">\(R_{t+1}\in\bf{R}\)</span>,然后接收一个新状态<spanclass="math inline">\(S_{t+1}\)</span>。下图为整个过程的示意图。</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/19e9ab8239fe6be8a413990a592b83c2.png?raw=true width=600></div><p></br>在任意时刻和状态下,agent都可以选择一个动作,选择的依据就是我们说的策略—即状态到动作的映射<spanclass="math inline">\(\pi(a\mid{s})\)</span>,而一个使得在任意时刻和状态下的长期回报都是最大的策略是我们最终需要得到的。所谓长期回报我们可以用每个时刻的立即回报来表示:</p><p><spanclass="math display">\[G_{t}=R_{t+1}+R_{t+2}+R_{t+3}+...=\sum_{k=t+1}^{\infty}R_{k}\tag{1.1}\]</span></p><p>但实际上我们一般会用下面更通用的公式来代替:</p><p><spanclass="math display">\[G_{t}=R_{t+1}+\gamma{R_{t+2}}+\gamma^2{R_{t+3}}+...+\gamma^{T-t-1}{R_{T}}=\sum_{k=0}^{T-t-1}\gamma^{k}R_{t+k+1}\tag{1.2}\]</span></p><p>其中<spanclass="math inline">\(\gamma\in[0,1]\)</span>称为回报折扣因子,表明了未来的回报相对于当前回报的重要程度。<spanclass="math inline">\(\gamma=0\)</span>时,相当于只考虑立即回报不考虑长期回报,<spanclass="math 
inline">\(\gamma=1\)</span>时,将长期回报和立即回报看得同等重要。<spanclass="math inline">\(T\in[1,\infty]\)</span>表示完成一次实验过程的总步数,<spanclass="math inline">\(T=\infty\)</span>和<spanclass="math inline">\(\gamma=1\)</span>不能同时满足,否则长期回报将无法收敛。特别地,我们将一次有限步数的实验称作一个单独的episodes,也就是经过有限步数后最终会接收一个终止状态,这一类的任务也叫做episodictasks。下面讨论的强化学习任务都是有限MDP的episodic tasks。</p><h3 id="马尔可夫决策过程">马尔可夫决策过程</h3><p>一个有限马尔可夫决策过程由一个四元组构成 <spanclass="math inline">\(M=({\bf{S}}, {\bf{A}}, {\bf{P}},{\bf{R}})\)</span><sup>[6]</sup>。如上所述,<spanclass="math inline">\(\bf{S}\)</span>表示状态集空间,<spanclass="math inline">\({\bf{A}}\)</span>表示动作集空间,<spanclass="math inline">\({\bf{P}}\)</span>表示状态转移概率矩阵,<spanclass="math inline">\({\bf{R}}\)</span>表示期望回报值。</p><p>在MDP中给定任何一个状态<spanclass="math inline">\(s\in\bf{S}\)</span>和动作<spanclass="math inline">\(a\in\bf{A}\)</span>,都会以某个概率转移到下一个状态<spanclass="math inline">\(s^{'}\)</span>,这个概率为<spanclass="math inline">\(p(s^{'}\mid s,a)={\bf{Pr}}\{S_{t+1}=s^{'}\mid S_{t}=s,A_{t}=a\}\in\bf{P}\)</span>,并获得下一个回报的期望值为<spanclass="math inline">\(r(s,a,s^{'})={\bf{E}}\left[R_{t+1}\mid{S_{t}=s,A_{t}=a,S_{t+1}=s^{'}}\right]\in\bf{R}\)</span>。</p><h3 id="值函数及贝尔曼公式">值函数及贝尔曼公式</h3><p>增强学习的最终结果是找到一个环境到动作的映射—即策略<spanclass="math inline">\(\pi(a\mid{s})\)</span>。如果一个策略只考虑立即回报,那么很可能就会掉入眼前陷阱。比如说有一个岔路口,往左回报是100,往右回报是10,如果策略只考虑立即回报,那肯定是往左,但往左走的下一次回报只有10,而往右走的下一次回报有200,可以看到这个策略并不是最优的策略,此外增强学习又往往有具有延迟回报的特点,在很多情况下的动作并不会产生立即回报,但这一系列动作的累积效果又的确会导致后续回报的产生,因此立即回报并不能说明策略的好坏。在几乎所有的强化学习理论中都会定义值函数来表示给定策略下期望的未来回报,并将值函数作为评估学习效果的指标。</p><p>值函数有多种定义,目前常见的是将值函数直接定义为未来回报的期望:</p><p><span class="math display">\[\upsilon_{\pi}(s)={\bf{E_{\pi}}}\left[G_{t}\mid{S_{t}=s}\right]={\bf{E_{\pi}}}\left[\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+1}\mid{S_{t}=s}\right]\tag{2.1}\]</span></p><p>上面表示的是在某个策略<spanclass="math inline">\(\pi\)</span>下,当环境处于状态<spanclass="math inline">\(s\)</span>时未来回报的期望,因此又叫做状态值函数(state-valuefunction forpolicy),只跟当前状态有关。同样,我们也可以定义动作值函数(action-valuefunction for policy),如下:</p><p><span class="math 
display">\[\begin{split}q_{\pi}(s,a)&={\bf{E_{\pi}}}\left[G_{t}\mid{S_{t}=s,A_{t}=a}\right]\\&={\bf{E_{\pi}}}\left[\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+1}\mid{S_{t}=s,A_{t}=a}\right]\end{split}\tag{2.2}\]</span></p><p>动作值函数表示在某个策略<span class="math inline">\(\pi\)</span>下,当环境处于状态<span class="math inline">\(s\)</span>时采取动作<span class="math inline">\(a\)</span>的未来回报的期望。可以看到动作值函数与状态值函数唯一的不同是动作值函数不仅指定了一个初始状态,而且也指定了初始动作,而状态值函数的初始动作是根据策略产生的。由于在MDP中,给定状态<span class="math inline">\(s\)</span>,agent根据策略选择动作<span class="math inline">\(a\)</span>,下个时刻将以概率<span class="math inline">\(p(s^{'}\mid{s,a})\)</span>转移到状态<span class="math inline">\(s^{'}\)</span>,因此值函数又可以改写成如下形式:</p><p><span class="math display">\[\begin{split}\upsilon_{\pi}(s)&={\bf{E_{\pi}}}\left[G_{t}\mid{S_{t}=s}\right]\\&={\bf{E_{\pi}}}\left[\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+1}\mid{S_{t}=s}\right]\\&={\bf{E_{\pi}}}\left[R_{t+1}+\gamma\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+2}\mid{S_{t}=s}\right]\\&=\sum_{a}\pi(a\mid{s})\cdot{\bf E}_{\pi}\left[R_{t+1}+\gamma\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+2}\mid{S_{t}=s,A_{t}}\right]\\&=\sum_{a}\pi(a\mid{s})\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma{\bf E}_{\pi}\left[\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+2}\mid{S_{t+1}=s^{'}}\right]\right]\\&=\sum_{a}\pi(a\mid{s})\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{\pi}(s^{'})\right]\end{split}\tag{2.3}\]</span></p><p>也就是说在策略<span class="math inline">\(\pi\)</span>下当前状态的值函数可以通过下一个状态的值函数来迭代求解,这个公式被称为<span class="math inline">\(\upsilon_{\pi}\)</span>的贝尔曼公式(Bellman equation for <span class="math inline">\(\upsilon_{\pi}\)</span>)。</p><p>同样,动作值函数也可以写成相似的形式:</p><p><span class="math display">\[\begin{split}q_{\pi}(s,a)&={\bf{E_{\pi}}}\left[G_{t}\mid{S_{t}=s,A_{t}=a}\right]\\&={\bf{E_{\pi}}}\left[R_{t+1}+\gamma\sum_{k=0}^{\infty}\gamma^{k}R_{t+k+2}\mid{S_{t}=s,A_{t}=a}\right]\\&=\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{\pi}(s^{'})\right]\end{split}\tag{2.4}\]</span></p><p><span class="math
inline">\(\upsilon_{\pi}(s)\)</span>也可以用<spanclass="math inline">\(q_{\pi}(s,a)\)</span>来表示:</p><p><spanclass="math display">\[\upsilon_{\pi}(s)=\sum_{a}\pi(a\mid{s})q_{\pi}(s,a)\tag{2.5}\]</span></p>下面是迭代计算<spanclass="math inline">\(\upsilon_{\pi}(s)\)</span>和<spanclass="math inline">\(q_{\pi}(s,a)\)</span>的图解<sup>[1]</sup>,可以与上述公式对照理解。<div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/205fd62a7177a841cdc79585cf1ed6ae.png?raw=true width=600></div><h3 id="最优值函数及贝尔曼最优公式">最优值函数及贝尔曼最优公式</h3><p>上面所说的值函数都是未来回报的期望值,而我们需要得到的最优策略必然是使得任意时刻未来回报的期望值都是最大的,也就是说我们的优化目标可以表示为:</p><p><spanclass="math display">\[\pi_{*}=\mathop{\arg\max}_{\mathbf{\pi}}\upsilon_{\pi}(s)\tag{2.6}\]</span></p><p>当然最优策略可能不止一个,但这些最优策略都有一个共同的特点,就是它们共享同样的状态值函数,这个状态值函数叫做最优状态值函数(optimalstate-value function),用<spanclass="math inline">\(\upsilon_{*}\)</span>来表示。对于所有的<spanclass="math inline">\(s\in\bf{S}\)</span>,</p><p><spanclass="math display">\[\upsilon_{*}(s)=\max_{\mathbf{\pi}}\upsilon_{\pi}(s)\tag{2.7}\]</span></p><p>最优策略同样也共享相同的动作值函数(optimal action-valuefunction),用<spanclass="math inline">\(q_{*}\)</span>来表示。对于所有的<spanclass="math inline">\(s\in\bf{S}\)</span>,<spanclass="math inline">\(a\in{\bf{A}}(s)\)</span>,</p><p><spanclass="math display">\[q_{*}(s,a)=\max_{\mathbf{\pi}}q_{\pi}(s,a)\tag{2.8}\]</span></p><p>回顾一下上面动作值函数的改写公式(2.4),<spanclass="math inline">\(q_{\pi}(s,a)=\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{\pi}(s^{'})\right]\)</span>,由于动作值函数表示的是给定初始动作,后面的动作遵循策略<spanclass="math inline">\(\pi\)</span>,因此最优动作值函数后面的动作应当遵循最优策略<spanclass="math inline">\(\pi_{*}\)</span>,不难得到下面的公式。 <spanclass="math display">\[q_{*}(s,a)=\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{*}(s^{'})\right]\tag{2.9}\]</span></p><p>至此,最优值函数的形式已经给出了,现在我们继续回顾一下公式(2.5)的意义,<spanclass="math inline">\(\upsilon_{\pi}(s)\)</span>的值是<spanclass="math inline">\(q_{\pi}(s,a)\)</span>的期望,那么必然存在<spanclass="math inline">\(\upsilon_{\pi}(s)\leq 
\max q_{\pi}(s,a)\)</span>。但对于最优策略来说,</p><p><span class="math display">\[\begin{split}\upsilon_{*}(s)&=\max_{\mathbf{a}} q_{*}(s,a) \\&=\max_{\mathbf{a}}\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{*}(s^{'})\right]\end{split}\tag{2.10}\]</span></p><p><span class="math display">\[q_{*}(s,a)=\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\max_{\mathbf{a^{'}}}q_{*}(s^{'},a^{'})\right]\tag{2.11}\]</span></p><p>与状态值函数的贝尔曼公式一样,最优状态值函数和最优动作值函数也可以表示成递归的形式,因此公式(2.10)和公式(2.11)又分别叫做状态值函数和动作值函数的贝尔曼最优公式(Bellman optimality equation)。因为没有<span class="math inline">\(\pi(a\mid{s})\)</span>,不需要根据策略生成动作,因此贝尔曼最优公式完全独立于策略,但如果我们已知<span class="math inline">\(\upsilon_{*}\)</span>或<span class="math inline">\(q_{*}\)</span>,都可以很容易地得到最优策略。</p><p>如果我们已知<span class="math inline">\(\upsilon_{*}\)</span>,而且在每一步都有多个动作可以选择,可以想到最优策略的<span class="math inline">\(\upsilon_{*}(s)\)</span>必然是满足贝尔曼最优公式的,因此至少有一个动作会满足公式中的最大化条件。任何一个采用上述动作并能够以非零概率转移到下一个状态的策略都是最优策略。我们可以把当前动作的选择看成是一个单步搜索(one-step search)的问题,在某个状态下单步搜索结果最大的动作即最优动作,而每个状态下都采取最优动作的策略即最优策略。如果我们已知<span class="math inline">\(q_{*}\)</span>,那么只需要在每一步都选择使得<span class="math inline">\(q_{*}(s,a)\)</span>最大的动作,就可以得到一个最优策略。</p><p>贝尔曼公式与贝尔曼最优公式是MDP求解的基础,下面主要介绍几种MDP求解的方法。</p><h2 id="动态规划方法">动态规划方法</h2><p>动态规划(dynamic programming)指的是能够用来解决给定环境模型,计算最优策略的算法总称。典型的动态规划算法存在两个问题,一是需要依赖一个非常好的环境状态转移模型,二是计算的开销非常大,因此在增强学习中几乎不会直接用动态规划求解MDP,但动态规划理论还是非常重要的,因为后面的一些算法都是在动态规划的基础上,摆脱模型依赖并尽可能地减少计算量。</p><h3 id="策略估计">策略估计</h3><p>首先,我们考虑一下如果已知策略<span class="math inline">\(\pi\)</span>,如何来计算<span class="math inline">\(\upsilon_{\pi}\)</span>。这个问题被称作DP迭代中的策略估计(policy evaluation)。</p><p>先举一个例子,一个岔路口有向左和向右两个方向,向左回报为10,向右回报为100,我们没有任何先验知识,但我们需要估计站在路口的值函数,也就是估计当前状态的值函数,该如何来估计呢?首先我们将值函数初始化为0,然后进行大量的尝试,每次都以0.5的概率选择方向左,并获得回报10,以0.5的概率选择方向右,获得回报100。那么只要能将这两个方向都至少遍历一遍,就可以得到该状态的值函数<span class="math inline">\(\upsilon_{随机策略}=\frac{1}{N}\sum_{i=0}^{N}{0.5\cdot R_{i}}\)</span>,其中<span class="math inline">\(N\)</span>为实验的总次数。</p><p>同样,我们也是采用相似的方法迭代来进行策略估计的。首先将所有的<span class="math
inline">\(\upsilon_{\pi}(s)\)</span>都初始化为0(或者任意值,但终止状态必须为0),然后采用如下公式更新所有状态<spanclass="math inline">\(s\)</span>的值函数。</p><p><span class="math display">\[\begin{split}\upsilon_{k+1}(s) &={\bf{E}}_{\pi}\left[R_{t+1}+\gamma\upsilon_{k}(S_{t+1})\mid S_{t}=s \right] \\&=\sum_{a}\pi(a\mid{s})\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{k}(s^{'})\right]\end{split}\tag{3.1}\]</span></p><p>其中<spanclass="math inline">\(\upsilon_{k+1}(s)\)</span>表示在当前策略下第<spanclass="math inline">\(k+1\)</span>次迭代状态<spanclass="math inline">\(s\)</span>的值函数,<spanclass="math inline">\(\upsilon_{k}(s^{'})\)</span>表示在当前策略下第<spanclass="math inline">\(k\)</span>次迭代状态<spanclass="math inline">\(s^{'}\)</span>的值函数,该公式就是用上一次迭代计算得到的值函数来更新本次迭代的值函数。在具体操作时,又有两种更新方法<sup>[6]</sup>,</p><ul><li>将第<spanclass="math inline">\(k\)</span>次迭代计算得到的所有状态值函数<spanclass="math inline">\(\left[\upsilon_{k}(s_{1}),\upsilon_{k}(s_{2}),\upsilon_{k}(s_{3}),...\right]\)</span>保存在一个数组中,第<spanclass="math inline">\(k+1\)</span>次迭代的<spanclass="math inline">\(\upsilon_{k+1}(s)\)</span>使用第<spanclass="math inline">\(k\)</span>次的<spanclass="math inline">\(\upsilon_{k}(s^{'})\)</span>进行更新,更新后的值保存在另一个数组中。</li><li>仅用一个数组来保存各状态的值函数,每次更新后就将原来的值覆盖。这样在第<spanclass="math inline">\(k+1\)</span>次迭代时<spanclass="math inline">\(\upsilon_{k+1}(s)\)</span>就有可能使用的是第<spanclass="math inline">\(k+1\)</span>次更新后的<spanclass="math inline">\(\upsilon_{k+1}(s^{'})\)</span>,这样可以及时地利用更新的值函数,收敛更快。</li></ul>下面为整个策略估计的算法过程:<div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/19f3246af64a89e7bf38a4d53ea26819.png?raw=true width=560></div><h3 id="策略改进">策略改进</h3><p>策略估计是为了计算当前策略下各状态的值函数,那得到值函数又有什么用呢?首先我们可以用来比较两个策略的好坏,如果状态值函数是已知的,那么就可以根据公式(2.4)计算动作值函数,如果一个策略<spanclass="math inline">\(\pi\)</span>的所有动作值函数都大于另一个策略<spanclass="math inline">\(\pi^{'}\)</span>,那么可以认为策略<spanclass="math inline">\(\pi\)</span>比策略<spanclass="math 
inline">\(\pi^{'}\)</span>更好。其次,最主要的用处是可以用来进行策略改进(policy improvement)。</p><p>仍然是上面岔路口的例子,但是假设无论向左还是向右,下一个路口都是唯一且相同的。起初由于没有任何先验知识,因此采用了一个随机策略,然后我们可以计算得到随机策略下的状态值函数,那么我们就可以进行策略改进了。具体的做法就是前面提到的单步搜索,向左时当前动作的回报为10,因此单步搜索的结果为10+<span class="math inline">\(\gamma\upsilon\)</span>,<span class="math inline">\(\upsilon\)</span>为下一个路口的值函数,而向右为100+<span class="math inline">\(\gamma\upsilon\)</span>,因此策略会更新为向右,而不再是随机了,显然策略被改进了。同时我们注意到,单步搜索计算的值正是动作值函数。</p><p>根据上面的例子,我们可以总结一下策略改进的方法:遍历所有的状态和所有可能的动作,采用贪婪算法进行策略的更新,即对所有<span class="math inline">\(s\in\bf S\)</span>,</p><p><span class="math display">\[\begin{split}\pi^{'}(s)&=\arg\max_{\mathbf{a}}q_{\pi}(s,a)\\&=\arg\max_{\mathbf{a}}\sum_{s^{'}}p(s^{'}\mid s,a)\left[r(s,a,s^{'})+\gamma\upsilon_{\pi}(s^{'})\right]\end{split}\tag{3.2}\]</span></p><p>现在我们已经知道如何计算当前策略的状态值函数,也知道可以根据动作值函数来更新策略,那下面就来讲讲如何从零开始求解最优策略。</p><h3 id="策略迭代">策略迭代</h3>一旦策略<span class="math inline">\(\pi\)</span>通过策略改进得到一个更好的策略<span class="math inline">\(\pi^{'}\)</span>,那么我们就可以通过策略估计算法,计算策略<span class="math inline">\(\pi^{'}\)</span>的状态值函数,并用公式(3.2)进行策略改进得到一个比策略<span class="math inline">\(\pi^{'}\)</span>更好的策略<span class="math inline">\(\pi^{''}\)</span>。如下图所示,经过无数次的策略估计和策略改进后,我们终将会收敛于最优策略<span class="math inline">\(\pi_{*}\)</span>。这种通过不断迭代地去改进策略的方法叫做策略迭代(policy iteration)。<div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/c9c7ec7b0709d5492f5e8cb8a6096b7e.png?raw=true width=600></div></br> 下面为整个策略迭代的算法过程:<div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/1b44935438fee7046950fcfddfd405c0.png?raw=true width=600></div><h3 id="值迭代">值迭代</h3><p>策略迭代算法需要不断地进行策略估计和策略改进,每次策略估计和改进都需要遍历一次所有的状态和动作,因此算法的计算量非常大,效率非常低。同时可以看到策略迭代的依据是贝尔曼公式,而如果直接利用贝尔曼最优公式会不会加速求解过程呢?事实上是可以的,下面的值迭代(value iteration)算法就是利用贝尔曼最优公式来提高求解效率的一种算法。</p><p>我们还是需要先迭代估计状态值函数,但不必每次迭代都进行策略改进。根据贝尔曼最优公式,可以直接用上一次迭代的最大动作值函数对当前迭代的状态值函数进行更新,如下所示:</p><p><span class="math display">\[\begin{split}\upsilon_{k+1}(s)&=\max_{\mathbf{a}} q_{k}(s,a)
\\&=\max_{\mathbf{a}}\sum_{s^{'}}p(s^{'}\mid{s,a})\left[r(s,a,s^{'})+\gamma\upsilon_{k}(s^{'})\right]\end{split}\tag{3.3}\]</span></p><p>值迭代算法的好处就是省去了每次迭代时的策略改进过程,并且由于每次迭代得到的<spanclass="math inline">\(\upsilon_{k+1}(s)\)</span>都要<spanclass="math inline">\(\geq\)</span>策略迭代得到的<spanclass="math inline">\(\upsilon_{k+1}(s)\)</span>,也就是说相同迭代次数下,策略迭代得到的策略肯定没有值迭代得到的策略好,因此能大大加快算法收敛。直到值函数收敛到最优值函数后,再通过最优值函数来计算得到最优策略,下面是值迭代算法的完整过程:</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/c94f41587e075ba0ab3af4a82ff99a17.png?raw=true width=560></div><p></br>一般来说值迭代和策略迭代都需要经过无数次迭代才能精确收敛到最优策略,而实践中我们往往会设定一个阈值<spanclass="math inline">\(\Delta\)</span>来作为迭代中止条件,即当所有的<spanclass="math inline">\(\upsilon_{\pi}(s)\)</span>变化量小于<spanclass="math inline">\(\Delta\)</span>时,我们就近似的认为获得了最优策略。值迭代和策略迭代都可以用来求解最优策略,但是都需要依赖一个现有的环境模型,而对环境进行精确建模往往是非常困难的,所以导致了动态规划方法在MDP求解时几乎不可用,当然如果状态转移是确定性的(<spanclass="math inline">\(p(s^{'}\mids,a)=1\)</span>),那就另当别论了。</p><h2 id="蒙特卡罗方法">蒙特卡罗方法</h2><p>下面我们要讲的是蒙特卡罗方法(Monte CarloMethods)。与动态规划不同,蒙特卡罗方法不需要知道环境的完整模型,仅仅需要经验就可以获得最优策略,这些经验可以通过与环境在线或模拟交互的方式获得。在线交互显然是不需要任何环境的先验知识,模拟交互虽然需要知道环境状态的转移,但与动态规划不同的是这里不需要知道具体的转移概率。</p><p>蒙特卡罗方法也称统计模拟方法,基本思想是通过对大量的重复随机事件进行统计,估计随机事件的概率分布或期望。一个典型的例子是利用蒙特卡罗方法计算圆周率。假设我们知道圆的面积公式为<spanclass="math inline">\(S=\pir^{2}\)</span>,那计算圆周率的公式自然就是<spanclass="math inline">\(\pi =\frac{S}{r^{2}}\)</span>,因此如果我们知道圆面积和圆半径,那么就可以求到圆周率。那么如何计算一个圆的面积呢?给定一个圆,我们可以画出这个圆的外切正方形,那么这个外切正方形的面积为<spanclass="math inline">\(S_{正方形}=4r^{2}\)</span>,现在我们往正方形区域随机投点,并统计点落在圆内的概率<spanclass="math inline">\(p\)</span>,那么圆面积可以这么计算:<spanclass="math inline">\(S_{圆}=p\cdot S_{正方形}\)</span>,因此<spanclass="math inline">\(\pi=4\cdotp\)</span>。可以想到,如果投点次数越多,<spanclass="math inline">\(p\)</span>估计越精确,<spanclass="math inline">\(\pi\)</span>的结果也就越接近真实值。</p><h3 id="蒙特卡罗策略估计">蒙特卡罗策略估计</h3><p>我们现在来考虑一下如何利用蒙特卡罗方法估计给定策略下的状态值函数。与上面计算圆周率的例子稍有不同的是,现在我们估计的是未来回报的期望,而不是概率,但基本思想是一样的。很显然,如果要估计<spanclass="math 
inline">\(\upsilon_{\pi}(s)\)</span>,我们首先需要根据给定策略生成大量的经验数据,然后从中统计从状态<span class="math inline">\(s\)</span>开始的未来回报的平均值,这个平均值就是我们估计的状态值函数。这种利用蒙特卡罗方法进行策略估计的算法又叫做蒙特卡罗策略估计(Monte Carlo Policy Evaluation)。</p><p>蒙特卡罗策略估计在具体实现时又分为first-visit MC methods和every-visit MC methods。由于在一个episode中,状态<span class="math inline">\(s\)</span>可能会出现多次,first-visit MC methods就是只统计第一次到达该状态的未来回报,而every-visit MC methods是所有达到该状态的未来回报都会统计累加起来。下面我们举例说明first-visit MC methods的估计方法<sup>[6]</sup>。</p><p>现在我们假设有如下一些样本(下图每一行都是在当前策略下的一个独立的episode),紫色实心点为状态<span class="math inline">\(s\)</span>,取折扣因子γ=1,即直接计算累积回报。</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/221402112851854.png?raw=true></div><p></br> 第一个episode中到达过两次状态<span class="math inline">\(s\)</span>,我们只计算第一次的未来回报<span class="math inline">\(R_{1}(s)=1-2+0+1-3+5=2\)</span>。假设我们已经用相同的方法计算得到<span class="math inline">\(R_{2}(s)=1\)</span>,<span class="math inline">\(R_{3}(s)=-5\)</span>,<span class="math inline">\(R_{4}(s)=4\)</span>。那么当前策略下状态<span class="math inline">\(s\)</span>的值函数</p><p><span class="math display">\[\upsilon_{\pi}(s)={\bf E}\left[R(s)\right]=\frac{1}{N}\sum_{i=1}^{N}\left[R_{i}(s)\right]=\frac{1}{4}\left(2+1-5+4\right)=0.5\]</span></p><p>同样,如果生成的episode数量越多,<span class="math inline">\(\upsilon_{\pi}(s)\)</span>的估计就越接近真实值,下面是具体的算法流程:</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/079fef1ab5cd0065007ae82d893b0520.png?raw=true width=560></div><p></br> 注意这里使用大写的<span class="math inline">\(V\)</span>表示状态值函数的估计,Sutton的理由是状态值函数一旦初始化,就会立即变成一个随机的值了,因为<span class="math inline">\(G\)</span>会随着生成的episode不同而不断变化。可以认为每次<span class="math inline">\(G\)</span>都为<span class="math inline">\(\upsilon_{\pi}(s)\)</span>的一个独立同分布估计,当数据量非常大时<span class="math inline">\(V(s)\)</span>将最终收敛于这个分布的均值。</p><h3 id="动作值函数的蒙特卡罗估计">动作值函数的蒙特卡罗估计</h3><p>由于我们没有完整的环境状态转移模型,因此即使我们得到当前策略的值函数,根据公式(3.2)也无法进行策略改进。既然我们可以估计得到状态值函数,那么肯定也可以用相同的方法直接估计动作值函数,在这里叫做动作值函数的蒙特卡罗估计(Monte Carlo Estimation of
Action Values)。</p><p>估计方法跟蒙特卡罗策略估计差不多,只不过我们需要找到所有的状态动作对(pairof state <span class="math inline">\(s\)</span> and action <spanclass="math inline">\(a\)</span>),然后统计每一个状态动作对的未来回报的平均值,即<spanclass="math inline">\(q_{\pi}(s,a)\)</span>的估计值。得到了<spanclass="math inline">\(q_{\pi}(s,a)\)</span>,我们就可以根据公式(3.2)进行策略改进了。</p><h3 id="蒙特卡罗控制">蒙特卡罗控制</h3><p>蒙特卡罗控制(Monte CarloControl)首要的问题就是如何估计最优策略。跟之前动态规划一样,这里也可以采用策略迭代和策略改进交替进行的方式,经过大量的迭代后收敛到最优策略。但蒙特卡罗方法有一个最大的问题,即我们需要产生无数的episode才能保证收敛到最优结果。无数的episode和大量的迭代导致计算量巨大,效率非常低。Sutton在书<sup>[1]</sup>中提到两种解决方法,其中一种方法是采用episode-by-episode的方式进行优化。</p><p>episode-by-episode的思想与动态规划中值迭代的in-place版本非常相似。在动态规划的值迭代中,我们每次迭代都直接覆盖更新值函数,因此能及时地利用到更新后的值函数,从而能加快收敛。episode-by-episode则是先用当前策略生成一个episode,然后根据这个episode进行动作值函数的更新,同时更新策略,并利用更新后的策略继续生成后续的episode。</p><p>下面是exploring starts的蒙特卡罗控制(Monte Carlo ES,exploringstarts指的是从一个随机的开始状态和动作生成一个episode)算法的完整过程:</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/608a1293a52fa134b5168042bf7fd519.png?raw=true width=560></div><p></br> 至于为何要使用exploringstarts,这与episode-by-episode在线生成episode的更新策略有关。还是上面的岔路口的例子,我们先随机指定一个策略,比如指定向左,那么使用该策略生成一个episode时必然也是向左,那么也就只能更新向左的动作值函数了,而无法更新向右的动作值函数。由于动作值函数是随机初始化的,如果向右的动作值函数初始值小于更新后的向左的动作值函数,那么下一次生成episode时仍然是向左,并且可以想象可能永远不会选择向右。但其实向右才是最优动作,因此上述更新的策略永远不可能是最优策略。但随机选择开始状态和动作,可以避免某些动作的值函数不会更新的问题,因此可以保证能获得最优策略。</p><p>当然也可以采用其他方法避免使用exploringstarts,下面要介绍的on-policy方法和off-policy方法就是其中的两种方法。</p><h3 id="on-policy蒙特卡罗控制">On-Policy蒙特卡罗控制</h3><p>前面的Monte Carlo ES算法使用exploringstarts是为了保证所有可能的动作值函数都能得到更新,从而保证能获得最优策略。如果策略本身就可以在任何状态下都采取所有可能的动作,而不是贪婪地只选择动作值函数最大的那个,那问题不就迎刃而解了吗。下面要讨论策略是非确定性的,也就是对于所有的状态<spanclass="math inline">\(s\)</span>和该状态下所有可能的动作<spanclass="math inline">\(a\)</span>都有<spanclass="math inline">\(\pi(a\mid s)>0\)</span>,并且用<spanclass="math inline">\(\epsilon-soft\)</span>策略生成episode。由于我们评估和改进的策略与生成episode的策略是相同的,因此叫做on-policy方法。</p><p>在<spanclass="math inline">\(\epsilon-soft\)</span>策略中,大多数时候策略会选择动作值函数最大的动作(或者换句话说,以<spanclass="math 
inline">\(1-\epsilon\)</span>, where <span class="math inline">\(\epsilon\)</span> is a small positive number), but with probability <span class="math inline">\(\epsilon\)</span> it picks an action uniformly at random from the other actions. The overall algorithm is:</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/a50e2ce4a881eea7b6b1a2a830f2db1d.png?raw=true width=560></div><h3 id="off-policy蒙特卡罗控制">Off-Policy Monte Carlo Control</h3><p>In off-policy methods, the policy that generates the episodes is not the same as the policy being evaluated and improved. The policy that generates episodes is called the behavior policy, and the policy being evaluated and improved is called the estimation policy. The advantage of this arrangement is that the behavior policy can be an <span class="math inline">\(\epsilon\)</span>-soft policy while the estimation policy remains deterministic. Only the algorithm is shown below; for the derivation see Sutton's book<sup>[1]</sup>.</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/pictures/6f3c3cd1ddbcbfb3fe3df6dc881ce4b8.png?raw=true width=560></div><h2 id="时间差分学习">Temporal-Difference Learning</h2><p>Temporal-difference (TD) learning combines the strengths of dynamic programming and Monte Carlo methods: like Monte Carlo it needs no model of the environment, and like dynamic programming it updates an estimate using only the estimates available at the next state, without waiting for the policy to roll out a complete episode.</p><h3 id="td预测">TD Prediction</h3><p>TD prediction, also called TD policy evaluation, estimates the state-value function <span class="math inline">\(\upsilon_{\pi}\)</span> of the current policy from a given stream of experience. Recall Monte Carlo control: we first roll out an episode, then use the historical episodes together with the newest one to compute the mean future return from state <span class="math inline">\(s\)</span>, which becomes the updated state value. With a slight modification we can update in a moving-average fashion instead, using only the difference between the current episode's return and the current state value. A simple every-visit MC method then updates as follows:</p><p><span class="math display">\[V(S_{t})=(1-\alpha)V(S_{t})+\alpha G_{t}=V(S_{t})+\alpha\left[G_{t}-V(S_{t}) \right]\tag{4-1}\]</span></p><p>Here <span class="math inline">\(V(S_{t})\)</span> is the value estimate of the state <span class="math inline">\(S_{t}\)</span> visited at time <span class="math inline">\(t\)</span>, <span class="math inline">\(G_{t}\)</span> is the total return from state <span class="math inline">\(S_{t}\)</span> to the end of the episode, and <span class="math inline">\(\alpha\)</span> is a constant step-size parameter (called the learning rate in gradient-descent methods). This update rule is known as constant-<span class="math inline">\(\alpha\)</span> MC. In this formula <span class="math inline">\(G_{t}\)</span> is only available once the whole episode has finished, so an update can happen only after a complete episode has been rolled out. The TD algorithm described next neatly removes this restriction: it only has to wait one time step, until the next state and reward are observed. The simplest such algorithm is called TD(0).</p><p><span class="math 
display">\[V(S_{t})=V(S_{t})+\alpha\left[R_{t+1}+\gamma V(S_{t+1})-V(S_{t}) \right]\tag{4-2}\]</span></p><p>Here we simply use <span class="math inline">\(R_{t+1}+\gamma V(S_{t+1})\)</span> as an estimate of the true future return in constant-<span class="math inline">\(\alpha\)</span> MC. Like Monte Carlo control, TD(0) is guaranteed to converge to the true state-value function, again provided there is enough experience data. Whether TD(0) or Monte Carlo converges faster has no definitive answer, though Sutton notes in the book that on some stochastic tasks TD(0) converges faster than constant-<span class="math inline">\(\alpha\)</span> MC. TD(0) performs an update at every time step; a more efficient scheme during training is batch updating, i.e. one update per batch.</p><p>Compared with MC, TD learning clearly has the following advantages<sup>[7]</sup>:</p><ul><li>Because TD prediction updates with a difference, and thanks to the step-size parameter <span class="math inline">\(\alpha\)</span>, TD learning's updates are smoother and have lower variance.</li><li>TD learning supports online training, because it does not have to wait for the whole episode to end before updating.</li><li>TD learning applies more broadly, including to tasks without a finite horizon.</li></ul><p>There are also drawbacks: TD learning is rather sensitive to initial values, and the result it converges to is biased.</p><h3 id="tdλ">TD(λ)</h3><p>Before introducing TD(λ), let us first look at n-step TD prediction. The TD(0) algorithm above updates after executing a single step from the current state, using the Bellman equation to estimate the future return of the current state. Could we instead execute n steps before updating, so that the Bellman estimate of the future return becomes more accurate? Indeed, when n equals the total number of steps in the episode, n-step TD prediction becomes exactly MC estimation.</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/强化学习/8aabe6f419dfeca3f4ee9de376ceb3bd.png?raw=true width=540></div><p></br></p><p>For 1-step, the future return equals the first reward plus the discounted value of the next state:</p><p><span class="math display">\[G_{t}^{(1)}=R_{t+1}+\gamma V(S_{t+1})\]</span></p><p>2-step executes one step more than 1-step, giving the return:</p><p><span class="math display">\[G_{t}^{(2)}=R_{t+1}+\gamma R_{t+2}+\gamma^{2} V(S_{t+2})\]</span></p><p>The n-step return is therefore:</p><p><span class="math display">\[G_{t}^{(n)}=R_{t+1}+\gamma R_{t+2}+\gamma^{2}R_{t+3}+\dots+\gamma^{n-1}R_{t+n}+\gamma^{n}V(S_{t+n})\]</span></p><p>Substituting <span class="math inline">\(G_{t}^{(n)}\)</span> for <span class="math inline">\(G_{t}\)</span> in equation (4-1), the n-step TD prediction update becomes:</p><p><span class="math display">\[V(S_{t})=V(S_{t})+\alpha\left[G_{t}^{(n)}-V(S_{t})\right]\tag{4-3}\]</span></p>n-step TD prediction can, to some extent, make the estimated value function more accurate and hence converge better, but the number of steps to wait before each update grows. The figure below compares the RMS error of n-step TD methods on a random walk task.<div data-align="center"><img 
src=https://github.com/hjchen2/personal/blob/master/blog/强化学习/3a775aa18ad1b86a07d3b75d52b1c25c.png?raw=true width=600></div><p></br> n-step TD uses only <span class="math inline">\(G_{t}^{(n)}\)</span>, the estimated return after executing n steps from the current state. To make fuller use of the information at every intermediate step, we can also average the <span class="math inline">\(G_{t}^{(n)}\)</span> obtained for different n. For example, we could take the mean of the 2-step and 4-step returns as the estimate of <span class="math inline">\(G_{t}\)</span>:</p><p><span class="math display">\[G_{t}^{avg}=\frac{1}{2}G_{t}^{(2)}+\frac{1}{2}G_{t}^{(4)}\]</span></p><p>TD(λ) can be understood as a particular n-step average in which each n-step return receives weight <span class="math inline">\((1-\lambda)\lambda^{(n-1)}\)</span>, with the weights still summing to 1, so that:</p><p><span class="math display">\[G_{t}^{(\lambda)}=(1-\lambda)\sum_{n=1}^{\infty}\lambda^{n-1}G_{t}^{(n)}\tag{4-4}\]</span></p><p>Equation (4-4) covers the case with no terminal state. For episodic tasks that do reach a terminal state, or for truncated tasks<sup>[note 1]</sup>, the weight of the last n-step return is set to <span class="math inline">\(\lambda^{T-t-1}\)</span> so that the weights still sum to 1, where <span class="math inline">\(T\)</span> is the total number of steps in the episode:</p><p><span class="math display">\[G_{t}^{(\lambda)}=(1-\lambda)\sum_{n=1}^{T-t-1}\lambda^{n-1}G_{t}^{(n)}+\lambda^{T-t-1}G_{t}\tag{4-5}\]</span></p><p>When <span class="math inline">\(\lambda=1\)</span>, TD(λ) is equivalent to MC; when <span class="math inline">\(\lambda=0\)</span>, TD(λ) degenerates to TD(0).</p><div data-align="center"><img src="https://github.com/hjchen2/personal/blob/master/blog/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/294acecc263a9668bd48e3403f9b5225.png?raw=true" width=540></div><p></br></p><h3 id="sarsa">Sarsa</h3><p>Next we consider how to use TD prediction for policy improvement. We know TD prediction can estimate the state-value function, and that equation (3-2) can then improve the policy. The problem is that <span class="math inline">\(p(s^{'}\mid s,a)\)</span> in equation (3-2) is unknown, so the policy cannot be improved directly. Recalling Monte Carlo control, TD can likewise estimate the action-value function <span class="math inline">\(q_{\pi}\)</span> directly. The update rule for <span class="math inline">\(q_{\pi}\)</span> mirrors the one for <span class="math inline">\(\upsilon_{\pi}\)</span>:</p><p><span class="math display">\[Q(S_t,A_t)=Q(S_t,A_t)+\alpha[R_{t+1}+\gamma Q(S_{t+1},A_{t+1})-Q(S_t,A_t)]\tag{4-6}\]</span></p>With the action-value function in hand, equation (3-2) can then be used for policy improvement. In equation (4-6), every transition from a non-terminal state <span class="math 
inline">\(S_t\)</span> to its successor state triggers one value-function update, and each update involves only <span class="math inline">\((S_t,A_t,R_{t+1},S_{t+1},A_{t+1})\)</span>, hence the name Sarsa. If the state <span class="math inline">\(S_{t+1}\)</span> is terminal, then <span class="math inline">\(Q(S_{t+1},A_{t+1})=0\)</span>. The complete Sarsa <span class="math inline">\(\epsilon\)</span>-greedy algorithm is shown below; since the policy being evaluated and improved is the same policy that generates the episodes, Sarsa is an on-policy method.<div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/强化学习/a8d5cc18d1df07802931d29487b29542.png?raw=true width=600></div><p></br> Sarsa's <span class="math inline">\(Q\)</span>-value update matches <span class="math inline">\(TD(0)\)</span>; the update can also take the <span class="math inline">\(TD(\lambda)\)</span> form, and that variant is Sarsa(λ). For details on Sarsa(λ), see chapter 7 of <em>Reinforcement Learning: An Introduction</em>.</p><h3 id="q-learning">Q-Learning</h3><p>Q-learning, introduced next, is an off-policy method and is regarded as one of the most important breakthroughs in reinforcement learning. In Q-learning, the action-value update is completely independent of the policy that generates the episodes, so the learned <span class="math inline">\(Q(S_t,A_t)\)</span> is directly an estimate of the optimal action-value function <span class="math inline">\(q_{*}\)</span>.</p><p><span class="math display">\[Q(S_t,A_t)=Q(S_t,A_t)+\alpha[R_{t+1}+\gamma\mathop \max_{a} Q(S_{t+1},a)-Q(S_t,A_t)]\tag{4-7}\]</span></p><p>Equation (4-7) is the one-step Q-learning update. The only difference from Sarsa is that, much like value iteration in dynamic programming, Q-learning updates with the best available <span class="math inline">\(Q(S_{t+1},a)\)</span>, which amounts to a target policy that always takes the action with the largest <span class="math inline">\(Q\)</span> value. This simplified the algorithm's analysis and convergence proofs, so Q-learning's convergence was established early on. But as with Monte Carlo control, always choosing the max-<span class="math inline">\(Q\)</span> action means some state-action pairs are never generated by the policy, and their action values never get updated. To guarantee convergence to the optimal policy, the algorithm below again uses an <span class="math inline">\(\epsilon\)</span>-greedy policy to generate episodes, while the update itself still uses the deterministic policy (i.e. the policy picks only the max-<span class="math inline">\(Q\)</span> action).</p><div data-align="center"><img src=https://github.com/hjchen2/personal/blob/master/blog/强化学习/3e01e229dc9f53393a25ded669fc0971.png?raw=true width=600></div><h2 id="dqn">DQN</h2><h2 id="dqn改进算法">Improvements to DQN</h2><h2 id="强化学习在内容推荐中的应用">Reinforcement Learning for Content Recommendation</h2><h2 id="参考资料">References</h2><p>1. Reinforcement Learning: An Introduction, Richard 
S. Sutton and Andrew G. Barto, 2012<br />2. Playing Atari with Deep Reinforcement Learning, DeepMind Technologies, arXiv, 2013.12<br />3. Human-level Control through Deep Reinforcement Learning, DeepMind Technologies, Nature, 2015.02<br />4. DeepMind blog: https://deepmind.com/blog/deep-reinforcement-learning<br />5. https://www.nervanasys.com/demystifying-deep-reinforcement-learning<br />6. http://www.cnblogs.com/jinxulin/p/3511298.html<br />7. Introduction to Reinforcement Learning, David Silver</p><h2 id="注释">Notes</h2><p>1. Truncated tasks: in reinforcement learning, continuing (non-episodic) tasks have no terminal state; to make training tractable, a continuing task can be truncated into episodes.</p>]]></content>
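The tabular TD-control updates discussed above can be made concrete with a short sketch. The following is a minimal illustration, not code from the post: it runs one-step Q-learning with an ε-greedy behavior policy on a hypothetical 4-state corridor MDP (states 0..3, terminal state 3, reward +1 on reaching the goal); the environment, state numbering, and hyperparameters are all assumptions made for the example.

```python
import random

N_STATES = 4          # states 0..3; state 3 is terminal
ACTIONS = (0, 1)      # 0 = step left, 1 = step right
ALPHA, GAMMA, EPSILON = 0.1, 0.9, 0.1

def step(s, a):
    """Deterministic transitions; reward +1 only on entering the terminal state."""
    s2 = min(s + 1, N_STATES - 1) if a == 1 else max(s - 1, 0)
    return s2, (1.0 if s2 == N_STATES - 1 else 0.0)

def q_learning(episodes=500, seed=0):
    random.seed(seed)
    Q = [[0.0 for _ in ACTIONS] for _ in range(N_STATES)]
    for _ in range(episodes):
        s = 0
        while s != N_STATES - 1:
            # epsilon-greedy behavior policy: explore with probability EPSILON
            if random.random() < EPSILON:
                a = random.choice(ACTIONS)
            else:
                a = max(ACTIONS, key=lambda x: Q[s][x])
            s2, r = step(s, a)
            # off-policy one-step update: bootstrap from max_a Q(s', a)
            Q[s][a] += ALPHA * (r + GAMMA * max(Q[s2]) - Q[s][a])
            s = s2
    return Q

Q = q_learning()
# The greedy policy should move right in every non-terminal state, with
# action values approaching 0.81, 0.9, 1.0 (i.e. gamma^(distance to goal - 1)).
print([max(ACTIONS, key=lambda a: Q[s][a]) for s in range(N_STATES - 1)])
```

Because the update target uses max over the successor's actions rather than the action the behavior policy actually took, the learned Q estimates the optimal action-value function even though exploration is ε-greedy; replacing `max(Q[s2])` with the Q-value of the next sampled action would turn this into the on-policy Sarsa update.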
<summary type="html"><h2 id="前言">Preface</h2>
<p>In recent years, DeepMind's success in applying reinforcement
learning to AlphaGo allowed a machine to surpass human performance on a complex task for the first time, making reinforcement learning one of the frontier directions of current machine learning research. Reinforcement learning itself has a long history: Sutton and colleagues began studying it as early as 1979, published the book Reinforcement Learning: An Introduction in 1998, and released its second edition in 2012; the first few parts of this article draw mainly on that book.</p></summary>
<category term="reinforcement learning" scheme="https://hjchen2.github.io/categories/reinforcement-learning/"/>
<category term="reinforcement learning" scheme="https://hjchen2.github.io/tags/reinforcement-learning/"/>
<category term="machine learning" scheme="https://hjchen2.github.io/tags/machine-learning/"/>
</entry>
</feed>