-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinterpreter.ml
1693 lines (1547 loc) · 60.8 KB
/
interpreter.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(* Global switch: when set, the interpreter prints tracing lines (such as
   "normalize_state") to stdout. *)
let debug_prints = ref false

(* [incr_mut r] increments the integer counter [r] in place. *)
let incr_mut r = incr r

(* [add_to_mut r v] adds [v] to the integer counter [r] in place. *)
let add_to_mut r v = r := v + !r
(* Heap addresses are plain integers. *)
type heap_id = int [@@deriving show]

(* Returns a fresh heap id on each call: 1, 2, 3, ... The counter is
   process-global and is never reset. *)
let gen_heap_id : unit -> heap_id =
  let last_issued = ref 0 in
  fun () ->
    last_issued := !last_issued + 1;
    !last_issued
(* A single (non-vectorized) runtime value.
   NOTE: states (which contain values) are ordered with [Stdlib.compare]
   (see StateSet), and that orders variants by constructor declaration order,
   so reordering these constructors changes the iteration order of state sets. *)
type scalar_value =
  | SNumber of Pico_number.t
  (* A range of numbers; presumably an abstraction of a number whose exact
     value is unknown — confirm against Pico_number_interval. *)
  | SNumberInterval of Pico_number_interval.t
  | SBool of bool
  (* A boolean whose value is not known; unary "not" maps it to itself. *)
  | SUnknownBool
  | SString of string
  (* nil; the optional string is a hint shown in "not a pointer" errors. *)
  | SNil of string option
  (* Reference to a Heap cell. *)
  | SPointer of heap_id
  (* A nil standing where a pointer is expected; the string is a hint shown
     when it is dereferenced. *)
  | SNilPointer of string
[@@deriving show]

(* Per-lane values of a vectorized state: lane i belongs to the i-th concrete
   state folded into the vector (see unvectorize_state). Lengths are expected
   to equal the owning state's [vector_size] and be greater than 1
   (see length_of_vector). *)
type vector_value =
  | VNumber of Pico_number.t Array.t
  | VNumberInterval of Pico_number_interval.t Array.t
  | VBool of bool Array.t
(* Lazily enumerates the lanes of a vector, in index order, as scalar values. *)
let seq_of_vector vec =
  let lift wrap lanes = Seq.map wrap (Array.to_seq lanes) in
  match vec with
  | VNumber lanes -> lift (fun n -> SNumber n) lanes
  | VNumberInterval lanes -> lift (fun i -> SNumberInterval i) lanes
  | VBool lanes -> lift (fun b -> SBool b) lanes
(* Renders a vector as "<Constructor>[lane, lane, ...]", each lane printed
   with [show_scalar_value]. *)
let show_vector_value vec =
  let tag =
    match vec with
    | VNumber _ -> "VNumber"
    | VNumberInterval _ -> "VNumberInterval"
    | VBool _ -> "VBool"
  in
  let lanes =
    String.concat ", "
      (List.of_seq (Seq.map show_scalar_value (seq_of_vector vec)))
  in
  Printf.sprintf "%s[%s]" tag lanes
(* Formatter for vector_value, needed by the derived printer of [value] below
   because vector_value itself has no [@@deriving show]. *)
let pp_vector_value fmt v = Format.pp_print_string fmt @@ show_vector_value v

(* A runtime value: either one scalar shared by every lane of the state
   (seq_of_value repeats it vector_size times) or a per-lane vector. *)
type value = Scalar of scalar_value | Vector of vector_value [@@deriving show]
(* Number of lanes in [vec], without checking the vector invariant. *)
let length_of_vector_unchecked vec =
  match vec with
  | VNumber lanes -> Array.length lanes
  | VNumberInterval lanes -> Array.length lanes
  | VBool lanes -> Array.length lanes

(* Number of lanes in [vec]. Asserts the invariant that a vector has at least
   two lanes (a single lane is always equal to itself, so
   value_unvectorize_if_possible would have turned it into a Scalar). *)
let length_of_vector vec =
  let len = length_of_vector_unchecked vec in
  assert (len > 1);
  len
(* [vector_compare_indices vec i j] compares lanes [i] and [j] of [vec] using
   the element type's own ordering (negative / zero / positive, like compare).
   Raises [Invalid_argument] if either index is out of bounds.
   The original wrapped this in [try ... with exn -> raise exn], a handler that
   only re-raised what it caught; it has been removed. *)
let vector_compare_indices vec i j =
  match vec with
  | VNumber a -> Pico_number.compare a.(i) a.(j)
  | VNumberInterval a -> Pico_number_interval.compare a.(i) a.(j)
  | VBool a -> Bool.compare a.(i) a.(j)
(* Collapses a vector whose lanes are all equal into a single Scalar; any
   other value is returned unchanged. Raises [Invalid_argument] on an empty
   vector. *)
let value_unvectorize_if_possible (value : value) : value =
  match value with
  | Scalar _ -> value
  | Vector vec ->
      let lanes = seq_of_vector vec in
      let first = fst (Option.get (Seq.uncons lanes)) in
      if Seq.exists (fun lane -> lane <> first) lanes then value
      else Scalar first
(* Builds a new value from the lanes of [vec] listed in [indices] (in that
   order), collapsing the result to a Scalar when all chosen lanes are equal. *)
let vector_extract_by_indices indices vec =
  let pick lanes = Array.map (fun idx -> lanes.(idx)) indices in
  let extracted =
    match vec with
    | VNumber lanes -> VNumber (pick lanes)
    | VNumberInterval lanes -> VNumberInterval (pick lanes)
    | VBool lanes -> VBool (pick lanes)
  in
  value_unvectorize_if_possible (Vector extracted)
(* Enumerates the [vector_size] lanes of a value: a Scalar is repeated, a
   Vector must have exactly [vector_size] lanes (asserted). *)
let seq_of_value vector_size v =
  match v with
  | Scalar s -> Seq.take vector_size (Seq.repeat s)
  | Vector lanes ->
      assert (length_of_vector lanes = vector_size);
      seq_of_vector lanes

(* The first lane of [vec]; raises [Invalid_argument] if [vec] is empty. *)
let example_of_vector vec =
  let first, _rest = Option.get (Seq.uncons (seq_of_vector vec)) in
  first
(* Only numeric scalars (exact or interval) are merged into vectors when
   states are combined; everything else must match exactly. *)
let can_vectorize_scalar_value s =
  match s with
  | SNumber _ | SNumberInterval _ -> true
  | SBool _ | SUnknownBool | SString _ | SNil _ | SPointer _ | SNilPointer _ ->
      false

(* A vector is judged by its first lane. *)
let can_vectorize_value v =
  let representative =
    match v with Scalar s -> s | Vector vec -> example_of_vector vec
  in
  can_vectorize_scalar_value representative
(* Builds a value from a non-empty sequence of scalars: a Vector of the first
   element's kind, collapsed to a Scalar when every element is equal.
   Raises [Invalid_argument] on an empty sequence, and [Failure] when kinds are
   mixed or the first element's kind cannot be stored in a vector.
   NOTE: [seq] is traversed twice (once to peek at the first element), so it
   must be re-iterable. *)
let value_of_non_empty_seq seq =
  let first, _ = Option.get (Seq.uncons seq) in
  let collect project =
    Array.of_seq
      (Seq.map
         (fun s ->
           match project s with
           | Some x -> x
           | None -> failwith "Cannot build vector of mixed types")
         seq)
  in
  let vec =
    match first with
    | SNumber _ -> VNumber (collect (function SNumber n -> Some n | _ -> None))
    | SNumberInterval _ ->
        VNumberInterval
          (collect (function SNumberInterval n -> Some n | _ -> None))
    | SBool _ -> VBool (collect (function SBool b -> Some b | _ -> None))
    | other ->
        assert (not (can_vectorize_scalar_value other));
        failwith ("Cannot build vector from " ^ show_scalar_value other)
  in
  value_unvectorize_if_possible (Vector vec)
(* Applies [f] to every lane; the result collapses to a Scalar if all mapped
   lanes are equal. *)
let map_vector (f : scalar_value -> scalar_value) vec =
  value_of_non_empty_seq (Seq.map f (seq_of_vector vec))
(* [filter_vector mask vec] keeps the lanes of [vec] whose [mask] entry is true.
   Returns [None] when no lane survives, [Some (Vector vec)] (the input itself)
   when every lane survives, and otherwise the extracted lanes (collapsed to a
   Scalar by vector_extract_by_indices if they are all equal).
   [mask] must have exactly one entry per lane (asserted). *)
let filter_vector (mask : bool Array.t) vec =
  incr_mut Perf.global_counters.filter_vector;
  let lane_count = Array.length mask in
  assert (lane_count = length_of_vector vec);
  let kept = Array.fold_left (fun n keep -> if keep then n + 1 else n) 0 mask in
  if kept = lane_count then Some (Vector vec)
  else if kept = 0 then None
  else begin
    (* counts only the filters that actually had to copy lanes *)
    incr_mut Perf.global_counters.filter_vector_real;
    let indices =
      Array.to_seqi mask
      |> Seq.filter_map (fun (i, keep) -> if keep then Some i else None)
      |> Array.of_seq
    in
    assert (Array.length indices = kept);
    Some (vector_extract_by_indices indices vec)
  end
(* Lane-wise [f] over two vectors of equal length (asserted); the result
   collapses to a Scalar when every output lane is equal. *)
let map2_vector (f : scalar_value -> scalar_value -> scalar_value) vec1 vec2 =
  let n1 = length_of_vector vec1 in
  let n2 = length_of_vector vec2 in
  assert (n1 = n2);
  value_of_non_empty_seq (Seq.map2 f (seq_of_vector vec1) (seq_of_vector vec2))
(* Array-backed sequence of heap ids holding the elements of an array table,
   exposing a list-like API. [append] and [drop_last] copy the whole array,
   so values are never mutated in place. *)
module ListForArrayTable = struct
  type t = heap_id Array.t

  let empty = [||]
  let length = Array.length
  let is_empty t = length t = 0
  let get = Array.get
  let append t item = Array.append t [| item |]
  let map = Array.map

  (* Like List.nth_opt: [None] past the end; a negative index raises. *)
  let nth_opt t i = if i < length t then Some (Array.get t i) else None
  let to_seq = Array.to_seq

  (* Everything but the last element, or [None] when empty. *)
  let drop_last t =
    match length t with 0 -> None | n -> Some (Array.sub t 0 (n - 1))

  let show t =
    "[" ^ String.concat ", " (Array.to_list (Array.map string_of_int t)) ^ "]"

  let pp fmt t = Format.pp_print_string fmt (show t)
end
(* Contents of one heap cell. *)
type heap_value =
  (* A plain value stored in a cell (presumably a global's storage, since
     global_env maps names to heap ids — confirm). *)
  | HValue of value
  (* Table with string keys; each field value lives in its own heap cell. *)
  | HObjectTable of (string * heap_id) list
  (* Sequence-like table; unary "#" returns its element count. *)
  | HArrayTable of ListForArrayTable.t
  (* Table with unknown contents; unary "#" reports length 0 for it. *)
  | HUnknownTable
  (* Function id plus captured values. *)
  | HClosure of Ir.global_id * value list
  (* Builtin function, referenced by name. *)
  | HBuiltinFun of string
[@@deriving show]
(* Maps keyed by heap id; heap_id is int, so Int's ordering is the same as
   Stdlib.compare on heap ids. *)
module HeapIdMap = Map.Make (Int)

(* Maps keyed by heap cell contents, ordered structurally. *)
module HeapValueMap = Map.Make (struct
  type t = heap_value

  let compare (a : t) (b : t) = Stdlib.compare a b
end)
(* Functional heap split into two generations:
   - [old_values]: a dense array indexed by heap id, built by old_of_list /
     old_of_seq (gc_heap builds heaps this way, with ids 0..n-1);
   - [new_values]: a map holding every value written with [add] afterwards.
   [add] never copies [old_values]; it records the write in [new_values] and
   flags the id in [old_changed] so that [find] knows to look in the map. *)
module Heap = struct
  type t = {
    old_values : heap_value Array.t;
    old_changed : bool Array.t; (* mutated, but only from false to true *)
    (* NOTE(review): [old_changed] is shared (not copied) by every heap derived
       via [add] or [map] from the same old_of_list result, so a flag set through
       one copy is seen by all. This stays correct because a flag only makes
       [find] try [new_values] first and fall back to [old_values]. *)
    new_values : heap_value HeapIdMap.t;
  }

  (* Where the current value for an id may live:
     [Old] = only in old_values; [New] = id lies beyond old_values, so only in
     new_values; [NewOrOld] = an old id that some heap sharing the array has
     overwritten, so check new_values first. *)
  type storage_location = Old | New | NewOrOld

  let find_storage_location (heap_id : heap_id) (heap : t) =
    if heap_id >= Array.length heap.old_values then New
    else if heap.old_changed.(heap_id) then NewOrOld
    else Old

  let empty =
    { old_values = [||]; old_changed = [||]; new_values = HeapIdMap.empty }

  (* Looks up [heap_id]. Raises [Invalid_argument] if the id lies beyond
     old_values and was never added (Option.get), or if it is negative
     (array index out of bounds). *)
  let find (heap_id : heap_id) (heap : t) : heap_value =
    let find_old () = heap.old_values.(heap_id) in
    let find_new_opt () = HeapIdMap.find_opt heap_id heap.new_values in
    match find_storage_location heap_id heap with
    | Old -> find_old ()
    | New -> Option.get @@ find_new_opt ()
    | NewOrOld -> (
        match find_new_opt () with Some v -> v | None -> find_old ())

  (* Returns a heap in which [heap_id] maps to [value]. Side effect: for an id
     inside old_values, sets its flag in the shared old_changed array. *)
  let add (heap_id : heap_id) (value : heap_value) (heap : t) : t =
    if find_storage_location heap_id heap <> New then
      heap.old_changed.(heap_id) <- true;
    { heap with new_values = HeapIdMap.add heap_id value heap.new_values }

  (* Builds a heap holding only old values. The ids must be exactly
     0, 1, ..., n-1 in list order (asserted). *)
  let old_of_list (s : (heap_id * heap_value) List.t) : t =
    List.iteri (fun i (heap_id, _) -> assert (heap_id = i)) s;
    {
      old_values = s |> List.to_seq |> Seq.map snd |> Array.of_seq;
      old_changed = Array.make (List.length s) false;
      new_values = HeapIdMap.empty;
    }

  let old_of_seq (s : (heap_id * heap_value) Seq.t) : t =
    s |> List.of_seq |> old_of_list

  (* Enumerates (id, value) for the old values. Asserts that nothing has been
     added to this heap since it was built. *)
  let seq_of_old (heap : t) : (heap_id * heap_value) Seq.t =
    assert (HeapIdMap.is_empty heap.new_values);
    Array.to_seqi heap.old_values

  (* Applies [f] to every stored value, old and new. The old_changed array is
     shared with [heap], not copied. *)
  let map (f : heap_value -> heap_value) (heap : t) : t =
    {
      old_values = heap.old_values |> Array.map f;
      old_changed = heap.old_changed;
      new_values = heap.new_values |> HeapIdMap.map f;
    }

  (* Debug rendering: count and contents of the old values, and only the
     count of the new values. *)
  let debug_show (heap : t) : string =
    Printf.sprintf "Old: %d [%s]\nNew: %d"
      (Array.length heap.old_values)
      (heap.old_values |> Array.to_seq |> Seq.map show_heap_value |> List.of_seq
     |> String.concat ", ")
      (HeapIdMap.cardinal heap.new_values)
end
(* Maps keyed by string; String.compare orders exactly like Stdlib.compare. *)
module StringMap = Map.Make (String)

(* Renders bindings as "\"key\" -> value" in increasing key order, joined by
   "; ". *)
let show_string_id_map show_v s =
  StringMap.fold
    (fun k v acc -> Printf.sprintf "\"%s\" -> %s" k (show_v v) :: acc)
    s []
  |> List.rev |> String.concat "; "
(* Interpreter state. A state may be "vectorized": it then stands for
   [vector_size] concrete states that differ only in the lanes of its
   Vector values (see unvectorize_state). *)
type state = {
  heap : Heap.t;
  (* Locals of the function currently executing (presumably — confirm). *)
  local_env : value Ir.LocalIdMap.t;
  (* Presumably the local environments of enclosing callers — confirm. *)
  outer_local_envs : value Ir.LocalIdMap.t list;
  (* Global name -> heap cell holding it. *)
  global_env : heap_id StringMap.t;
  (* Printed output; ordering (newest first?) is not visible here — confirm. *)
  prints : string list;
  (* Number of lanes; shape_of_normalized_state sets it to -1 in shapes. *)
  vector_size : int;
}

(* A builtin receives the calling state and its arguments and returns a list
   of (state, result) pairs — presumably one per possible outcome. *)
type builtin_fun = state -> value list -> (state * value) list
(* Always raises [Failure] explaining why [v] cannot be used as a pointer.
   Calling it with an actual pointer is itself reported as a failure. *)
let failwith_not_pointer (v : value) =
  let message =
    match v with
    | Scalar (SPointer _ | SNilPointer _) ->
        "failwith_not_pointer called with pointer"
    | Scalar (SNil (Some hint)) ->
        Printf.sprintf "Value is not a pointer (value is nil; %s)" hint
    | Scalar (SNil None) -> "Value is not a pointer (value is nil)"
    | _ -> "Value is not a pointer"
  in
  failwith message
(* Reads local [local_id] and returns the heap id it points to. Raises
   [Not_found] if the local is unbound, and [Failure] if it holds a nil
   pointer or a non-pointer. *)
let heap_id_from_pointer_local (state : state) (local_id : Ir.local_id) :
    heap_id =
  let v = Ir.LocalIdMap.find local_id state.local_env in
  match v with
  | Scalar (SPointer heap_id) -> heap_id
  | Scalar (SNilPointer hint) ->
      failwith ("Attempted to dereference nil (" ^ hint ^ ")")
  | _ -> failwith_not_pointer v
(* Stores [heap_value] under a freshly generated id; returns the new state and
   that id. *)
let state_heap_add (state : state) (heap_value : heap_value) : state * heap_id =
  let heap_id = gen_heap_id () in
  ({ state with heap = Heap.add heap_id heap_value state.heap }, heap_id)

(* Replaces the cell [heap_id] with [update] applied to its current
   contents. *)
let state_heap_update (state : state) (update : heap_value -> heap_value)
    (heap_id : heap_id) : state =
  let updated = update (Heap.find heap_id state.heap) in
  { state with heap = Heap.add heap_id updated state.heap }
(* Rewrites the heap id inside a pointer with [f]; every other value
   (including vectors, which never hold pointers) is kept. *)
let map_value_references f v : value =
  match v with
  | Scalar (SPointer heap_id) -> Scalar (SPointer (f heap_id))
  | Scalar
      ( SNumber _ | SNumberInterval _ | SBool _ | SUnknownBool | SString _
      | SNil _ | SNilPointer _ ) ->
      v
  | Vector vec -> Vector vec

(* Rewrites every heap id referenced by a heap cell with [f], visiting
   references left to right (gc_heap depends on this order for numbering). *)
let map_heap_value_references f v : heap_value =
  let on_value = map_value_references f in
  match v with
  | HValue value -> HValue (on_value value)
  | HObjectTable fields ->
      HObjectTable (List.map (fun (key, id) -> (key, f id)) fields)
  | HArrayTable ids -> HArrayTable (ListForArrayTable.map f ids)
  | HUnknownTable -> HUnknownTable
  | HClosure (global_id, captures) ->
      HClosure (global_id, List.map on_value captures)
  | HBuiltinFun name -> HBuiltinFun name
(* [gc_heap state] garbage-collects and renumbers the heap.

   Every cell reachable from the roots (global_env, outer_local_envs,
   local_env) is copied into a fresh old-only heap, and unreachable cells are
   dropped. Ids 0, 1, 2, ... are assigned in depth-first pre-order of
   discovery (an id is assigned before the cell's children are visited, so
   cycles terminate). All references in the roots and in the copied cells are
   rewritten to the new ids.

   The numbering depends only on traversal order. The point is that states with
   the same reachable structure get identical heaps, which is what
   normalize_state needs. The original built the result in one record literal,
   so the order roots were visited in followed OCaml's unspecified record-field
   evaluation order. The roots are now visited in an explicit order. *)
let gc_heap (state : state) : state =
  Perf.count_and_time Perf.global_counters.gc @@ fun () ->
  let old_heap = state.heap in
  let new_heap_values = ref [] in
  let new_ids_by_old_ids = ref HeapIdMap.empty in
  let next_id = ref 0 in
  let rec visit old_id =
    match HeapIdMap.find_opt old_id !new_ids_by_old_ids with
    | Some new_id -> new_id
    | None ->
        let new_id = !next_id in
        next_id := !next_id + 1;
        new_ids_by_old_ids := HeapIdMap.add old_id new_id !new_ids_by_old_ids;
        (* ref is so that we can add to the list before recursing *)
        let visited_value_ref = ref None in
        new_heap_values := (new_id, visited_value_ref) :: !new_heap_values;
        visited_value_ref :=
          Some (map_heap_value_references visit (Heap.find old_id old_heap));
        new_id
  in
  (* Explicit, fixed root order: it determines the id numbering. *)
  let global_env = StringMap.map visit state.global_env in
  let outer_local_envs =
    List.map
      (Ir.LocalIdMap.map @@ map_value_references visit)
      state.outer_local_envs
  in
  let local_env =
    Ir.LocalIdMap.map (map_value_references visit) state.local_env
  in
  (* new_heap_values was built newest-first; rev_map restores ascending ids,
     as Heap.old_of_list requires. *)
  let heap =
    !new_heap_values
    |> List.rev_map (fun (id, v) -> (id, Option.get !v))
    |> Heap.old_of_list
  in
  {
    heap;
    global_env;
    local_env;
    outer_local_envs;
    prints = state.prints;
    vector_size = state.vector_size;
  }
let normalize_state_maps_except_heap (state : state) : state =
  (* The = operator (and Stdlib.compare) on maps inspects the internal tree
     shape, not just the bindings as Map.equal does. Rebuilding every map from
     its sorted bindings gives maps with equal bindings an identical tree, so
     structural comparison of states becomes meaningful. The heap is left to
     gc_heap. *)
  Perf.count_and_time Perf.global_counters.normalize_state_maps_except_heap
  @@ fun () ->
  let rebuild_local_env env = Ir.LocalIdMap.of_seq (Ir.LocalIdMap.to_seq env) in
  {
    state with
    global_env = StringMap.of_seq (StringMap.to_seq state.global_env);
    local_env = rebuild_local_env state.local_env;
    outer_local_envs = List.map rebuild_local_env state.outer_local_envs;
  }
(* Applies [f] to every value held by the state: HValue heap cells (old and
   new), locals, and outer locals. global_env (which holds heap ids), prints
   and vector_size are copied unchanged; non-HValue heap cells are kept.
   NOTE(review): the order in which heap / local_env / outer_local_envs are
   visited follows OCaml's (unspecified) record-field evaluation order.
   unpack_state_vector_values and pack_state_vector_values rely on that order
   being the same on every call, so restructure this only together with them. *)
let state_map_values (f : value -> value) (state : state) : state =
  let f_heap_value = function HValue v -> HValue (f v) | v -> v in
  {
    heap = Heap.map f_heap_value state.heap;
    local_env = Ir.LocalIdMap.map f state.local_env;
    outer_local_envs = List.map (Ir.LocalIdMap.map f) state.outer_local_envs;
    global_env = state.global_env;
    prints = state.prints;
    vector_size = state.vector_size;
  }
(* Asserts that every Vector in the state has more than one lane and exactly
   [state.vector_size] lanes. *)
let state_assert_vector_lengths (state : state) =
  let check v =
    (match v with
    | Vector vec ->
        let len = length_of_vector vec in
        assert (len > 1);
        assert (len = state.vector_size)
    | Scalar _ -> ());
    v
  in
  let _ : state = state_map_values check state in
  ()

(* Canonical form of a state: rebuilt maps, then a garbage-collected and
   renumbered heap, so that structurally equal states compare equal. *)
let normalize_state (state : state) : state =
  if !debug_prints then print_string "normalize_state\n";
  let normalized = gc_heap (normalize_state_maps_except_heap state) in
  state_assert_vector_lengths normalized;
  normalized
(* Structural ordering of states. Stdlib.compare also looks at the internal
   tree shape of the maps inside a state, so states are normalized
   (normalize_state) before being used as keys for equal states to compare
   equal. *)
module OrderedState = struct
  type t = state

  let compare (a : t) (b : t) : int = Stdlib.compare a b
end

module StateSet = Set.Make (OrderedState)
module StateMap = Map.Make (OrderedState)
(* A collection of states whose normalization (normalize_state, which
   garbage-collects the heap) and deduplication are deferred until needed.
   - [NormalizedSet]: every state is normalized and duplicates are removed.
   - [NormalizedList]: every state is normalized; duplicates may remain.
   - [NonNormalizedList]: no guarantee; states are normalized on demand. *)
module LazyStateSet = struct
  type t =
    | NormalizedSet of StateSet.t
    | NormalizedList of state list
    | NonNormalizedList of state list

  let empty = NormalizedList []

  let is_empty = function
    | NormalizedSet set -> StateSet.is_empty set
    | NormalizedList [] -> true
    | NormalizedList _ -> false
    | NonNormalizedList [] -> true
    | NonNormalizedList _ -> false

  (* All states as stored: no normalization, no deduplication. *)
  let to_non_normalized_non_deduped_seq (t : t) : state Seq.t =
    match t with
    | NormalizedSet set -> StateSet.to_seq set
    | NormalizedList list -> List.to_seq list
    | NonNormalizedList list -> List.to_seq list

  (* All states, normalizing (lazily, as the sequence is consumed) any that are
     not known to be normalized; duplicates may remain. *)
  let to_normalized_non_deduped_seq (t : t) : state Seq.t =
    match t with
    | NormalizedSet set -> StateSet.to_seq set
    | NormalizedList list -> List.to_seq list
    | NonNormalizedList list -> List.to_seq list |> Seq.map normalize_state

  let of_list (list : state list) : t = NonNormalizedList list

  (* True when every element is known to be normalized (trivially for an empty
     list). *)
  let has_normalized_elements (t : t) : bool =
    match t with
    | NormalizedSet _ -> true
    | NormalizedList _ -> true
    | NonNormalizedList [] -> true
    | NonNormalizedList _ -> false

  let to_normalized_state_set (t : t) : StateSet.t =
    match t with
    | NormalizedSet t -> t
    | _ -> t |> to_normalized_non_deduped_seq |> StateSet.of_seq

  let normalize (t : t) : t = NormalizedSet (to_normalized_state_set t)

  let normalize_no_dedup (t : t) : t =
    NormalizedList (to_normalized_non_deduped_seq t |> List.of_seq)

  (* Concatenation without deduplication; the result keeps a normalized tag
     only if both inputs are known to be normalized. *)
  let union (a : t) (b : t) : t =
    if is_empty a then b
    else if is_empty b then a
    else
      let ab_list =
        Seq.append
          (to_non_normalized_non_deduped_seq a)
          (to_non_normalized_non_deduped_seq b)
        |> List.of_seq
      in
      if has_normalized_elements a && has_normalized_elements b then
        NormalizedList ab_list
      else NormalizedList ab_list |> fun l ->
        (match l with NormalizedList x -> NonNormalizedList x | other -> other)

  (* [union_diff a b] = (deduplicated union of a and b, the normalized states
     of [a] not already in [b], each listed once). Membership is detected via
     Set.add returning a physically equal set when the element is present. *)
  let union_diff (a : t) (b : t) : t * t =
    if !debug_prints then Printf.printf "union diff b\n";
    let b = to_normalized_state_set b in
    if !debug_prints then Printf.printf "union diff rest\n";
    let union, diff =
      Seq.fold_left
        (fun (union, diff) v ->
          let new_union = StateSet.add v union in
          if new_union == union then (union, diff) else (new_union, v :: diff))
        (b, [])
        (to_normalized_non_deduped_seq a)
    in
    (NormalizedSet union, NormalizedList diff)

  (* Applies [f] to every state. If [f] returns a physically different state
     for any element, the result is a NonNormalizedList (the new states are not
     known to be normalized); otherwise [t] itself is returned. *)
  let map (f : state -> state) (t : t) : t =
    let changed = ref false in
    let new_list =
      t |> to_non_normalized_non_deduped_seq
      |> Seq.map (fun state ->
             let new_state = f state in
             if new_state != state then changed := true;
             new_state)
      |> List.of_seq
    in
    if !changed then NonNormalizedList new_list else t

  let filter (f : state -> bool) = function
    | NormalizedSet set -> NormalizedSet (StateSet.filter f set)
    | NormalizedList list -> NormalizedList (List.filter f list)
    | NonNormalizedList list -> NonNormalizedList (List.filter f list)

  (* Keeps the states for which [f] returns [Some], replaced by [f]'s result.
     Like [map], if [f] rewrites any state (returns a physically different one)
     the result is downgraded to NonNormalizedList, because a rewritten state is
     not known to be normalized. The original kept the normalized tag
     unconditionally. *)
  let filter_map (f : state -> state option) (t : t) : t =
    let changed = ref false in
    let tracked state =
      let result = f state in
      (match result with
      | Some new_state -> if new_state != state then changed := true
      | None -> ());
      result
    in
    match t with
    | NormalizedSet set ->
        let filtered = StateSet.filter_map tracked set in
        if !changed then NonNormalizedList (StateSet.elements filtered)
        else NormalizedSet filtered
    | NormalizedList list ->
        let filtered = List.filter_map tracked list in
        if !changed then NonNormalizedList filtered else NormalizedList filtered
    | NonNormalizedList list -> NonNormalizedList (List.filter_map f list)

  (* Number of stored states (an upper bound on distinct states for lists). *)
  let cardinal_upper_bound (t : t) : int =
    match t with
    | NormalizedSet set -> StateSet.cardinal set
    | NormalizedList list -> List.length list
    | NonNormalizedList list -> List.length list
end
(* Sets of (state, returned value) pairs, ordered structurally; states should
   be normalized first for equal pairs to compare equal. *)
module StateAndReturnSet = Set.Make (struct
  type t = state * value

  let compare (a : t) (b : t) : int = Stdlib.compare a b
end)
(* Like LazyStateSet but for (state, returned value) pairs. States are
   normalized on demand; the value part is never touched. *)
module LazyStateAndReturnSet = struct
  type t =
    | NormalizedSet of StateAndReturnSet.t
    | NonNormalizedList of (state * value) list

  let empty = NonNormalizedList []

  let is_empty (t : t) : bool =
    match t with
    | NormalizedSet set -> StateAndReturnSet.is_empty set
    | NonNormalizedList list -> ( match list with [] -> true | _ :: _ -> false)

  (* All pairs as stored. *)
  let to_non_normalized_non_deduped_seq (t : t) : (state * value) Seq.t =
    match t with
    | NormalizedSet set -> StateAndReturnSet.to_seq set
    | NonNormalizedList list -> List.to_seq list

  (* All pairs, with non-normalized states normalized lazily. *)
  let to_normalized_non_deduped_seq (t : t) : (state * value) Seq.t =
    let normalize_pair (state, value) = (normalize_state state, value) in
    match t with
    | NormalizedSet set -> StateAndReturnSet.to_seq set
    | NonNormalizedList list -> Seq.map normalize_pair (List.to_seq list)

  let of_list (list : (state * value) list) : t = NonNormalizedList list

  let to_normalized_state_and_return_set (t : t) : StateAndReturnSet.t =
    match t with
    | NormalizedSet set -> set
    | NonNormalizedList _ ->
        StateAndReturnSet.of_seq (to_normalized_non_deduped_seq t)

  let normalize (t : t) : t =
    NormalizedSet (to_normalized_state_and_return_set t)

  (* Concatenation without deduplication; always non-normalized unless one
     side is empty. *)
  let union (a : t) (b : t) : t =
    match (is_empty a, is_empty b) with
    | true, _ -> b
    | false, true -> a
    | false, false ->
        let both =
          Seq.append
            (to_non_normalized_non_deduped_seq a)
            (to_non_normalized_non_deduped_seq b)
        in
        NonNormalizedList (List.of_seq both)

  let cardinal_upper_bound (t : t) : int =
    match t with
    | NormalizedSet set -> StateAndReturnSet.cardinal set
    | NonNormalizedList list -> List.length list
end
(* Either a set of plain states or a set of (state, returned value) pairs.
   The two kinds are never combined; trying to do so raises [Failure]. *)
module StateAndMaybeReturnSet = struct
  type t =
    | StateSet of LazyStateSet.t
    | StateAndReturnSet of LazyStateAndReturnSet.t

  let union (a : t) (b : t) : t =
    match (a, b) with
    | StateSet x, StateSet y -> StateSet (LazyStateSet.union x y)
    | StateAndReturnSet x, StateAndReturnSet y ->
        StateAndReturnSet (LazyStateAndReturnSet.union x y)
    | StateSet _, StateAndReturnSet _ | StateAndReturnSet _, StateSet _ ->
        failwith "Cannot union StateSet and StateAndReturnSet"

  (* union_diff a b = (union a b, diff a b), where equality is Set.equal
     rather than =. Only implemented for plain state sets. *)
  let union_diff (a : t) (b : t) : t * t =
    match (a, b) with
    | StateSet x, StateSet y ->
        let merged, added = LazyStateSet.union_diff x y in
        (StateSet merged, StateSet added)
    | StateAndReturnSet _, StateAndReturnSet _ ->
        failwith "not implemented (yet?)"
    | StateSet _, StateAndReturnSet _ | StateAndReturnSet _, StateSet _ ->
        failwith "Cannot union_diff StateSet and StateAndReturnSet"

  let is_empty = function
    | StateSet s -> LazyStateSet.is_empty s
    | StateAndReturnSet s -> LazyStateAndReturnSet.is_empty s
end
(* [zip_seq_list seqs] transposes a list of equally long sequences into a
   sequence of lists: the i-th list holds the i-th element of every input, in
   input order. An empty input list gives an empty sequence. The first
   sequence decides when to stop; at that point all others are asserted to be
   exhausted too. *)
let zip_seq_list (seq_list : 'a Seq.t list) : 'a list Seq.t =
  let rec advance (seqs : 'a Seq.t list) : 'a list Seq.t =
    match seqs with
    | [] -> Seq.empty
    | leader :: _ ->
        if Seq.is_empty leader then begin
          assert (List.for_all Seq.is_empty seqs);
          Seq.empty
        end
        else
          let heads, tails =
            List.split (List.map (fun seq -> Option.get (Seq.uncons seq)) seqs)
          in
          fun () -> Seq.Cons (heads, advance tails)
  in
  advance seq_list
(* [zip_map_map f to_seq of_seq maps] merges maps that all have the same keys,
   in the same order: for each key, [f] receives that key's values from every
   map and its result becomes the value in the output map.
   NOTE: the maps are visited via List.rev_map, so [f] sees the values in
   reverse order of [maps]. This is the same for every key. The keys of each
   column must all be equal (asserted). *)
let zip_map_map (f : 'va list -> 'vb) (to_seq : 'ma -> ('k * 'va) Seq.t)
    (of_seq : ('k * 'vb) Seq.t -> 'mb) (maps : 'ma list) : 'm =
  let merge_column (column : ('k * 'va) list) : 'k * 'vb =
    let keys, values = List.split column in
    match keys with
    | [] -> invalid_arg "option is None"
    | key :: _ ->
        assert (List.for_all (fun other -> other = key) keys);
        (key, f values)
  in
  of_seq (Seq.map merge_column (zip_seq_list (List.rev_map to_seq maps)))
(* [zip_map_state_values f shape states] builds one state from [states], which
   are expected to share [shape]'s structure. It combines the values at each
   local, each outer local and each HValue heap cell with [f], position by
   position. Each value handed to [f] is paired with the vector_size of the
   state it came from; through zip_map_map, [f] sees them in reverse order of
   [states].
   Heaps are read with Heap.seq_of_old, which asserts there are no "new"
   values, so the states must be normalized (gc_heap yields old-only heaps).
   Non-HValue cells are assumed identical across states and one of them is
   kept; a cell that is an HValue in some states but not others makes
   Option.get raise. global_env, prints and vector_size come from [shape].
   vectorize_states overrides vector_size afterwards. *)
let zip_map_state_values (f : (value * int) list -> value) (shape : state)
    (states : state list) : state =
  let f_heap_value (heap_values : (heap_value * int) list) : heap_value =
    let values =
      List.map
        (function HValue v, vector_size -> Some (v, vector_size) | _ -> None)
        heap_values
    in
    match values with
    | Some _ :: _ -> HValue (values |> List.map Option.get |> f)
    | None :: _ ->
        assert (List.for_all Option.is_none values);
        heap_values |> List.hd |> fst
    | [] -> assert false
  in
  (* Tags every binding of one map with the vector_size of its state. *)
  let lift_to_seq to_seq (to_seq_input, vector_size) =
    to_seq_input |> to_seq |> Seq.map (fun (k, v) -> (k, (v, vector_size)))
  in
  let extract f = List.map (fun state -> (f state, state.vector_size)) states in
  {
    heap =
      zip_map_map f_heap_value
        (lift_to_seq Heap.seq_of_old)
        Heap.old_of_seq
        (extract (fun state -> state.heap));
    local_env =
      zip_map_map f
        (lift_to_seq Ir.LocalIdMap.to_seq)
        Ir.LocalIdMap.of_seq
        (extract (fun state -> state.local_env));
    (* Every state must have at least as many outer envs as [shape]. *)
    outer_local_envs =
      List.mapi
        (fun i _ ->
          zip_map_map f
            (lift_to_seq Ir.LocalIdMap.to_seq)
            Ir.LocalIdMap.of_seq
            (extract (fun state -> List.nth state.outer_local_envs i)))
        shape.outer_local_envs;
    global_env = shape.global_env;
    prints = shape.prints;
    vector_size = shape.vector_size;
  }
(* Maps a value to its "shape". Vectorizable numeric values, whether scalar or
   vector, all become the same zero placeholder of their kind, so states that
   differ only in such numbers share a shape. Every other value is kept as is,
   so states must agree on it exactly to share a shape.
   The original guarded everything behind [when not (can_vectorize_value v)],
   which made the SBool / SUnknownBool / SString / SNil / SPointer /
   SNilPointer arms unreachable. It also recursed for vectors. This
   exhaustive match gives the same result and lets the compiler flag new
   constructors. *)
let normalize_value_for_shape (v : value) : value =
  match v with
  | Scalar (SNumber _) | Vector (VNumber _) -> Scalar (SNumber Pico_number.zero)
  | Scalar (SNumberInterval _) | Vector (VNumberInterval _) ->
      Scalar (SNumberInterval (Pico_number_interval.of_number Pico_number.zero))
  | Scalar
      ( SBool _ | SUnknownBool | SString _ | SNil _ | SPointer _
      | SNilPointer _ )
  | Vector (VBool _) ->
      v
(* The shape of a normalized state: numeric values replaced by placeholders
   (normalize_value_for_shape) and vector_size set to -1, so states of any
   lane count can share a shape. *)
let shape_of_normalized_state (state : state) : state =
  { (state_map_values normalize_value_for_shape state) with vector_size = -1 }

(* Every vector_value in the state, in state_map_values traversal order (the
   order pack_state_vector_values expects). *)
let unpack_state_vector_values (state : state) : vector_value list =
  let found = ref [] in
  let record v =
    (match v with Vector vec -> found := vec :: !found | Scalar _ -> ());
    v
  in
  let _ : state = state_map_values record state in
  List.rev !found
(* [pack_state_vector_values state old_values new_values] is the inverse of
   unpack_state_vector_values. It walks [state] in the same state_map_values
   order and replaces the k-th Vector found with the k-th element of
   [new_values]. [old_values] must be exactly what
   [unpack_state_vector_values state] returned. Each element is checked by
   physical equality (==) against the vector actually found, which guards
   against the two traversals getting out of sync. Non-vector values are kept.
   Asserts if either list runs out early; surplus elements are ignored. *)
let pack_state_vector_values (state : state) (old_values : vector_value list)
    (new_values : value list) : state =
  let old_values = ref old_values in
  let new_values = ref new_values in
  state_map_values
    (fun v ->
      match v with
      | Vector called_value ->
          let old_value =
            match !old_values with
            | h :: t ->
                old_values := t;
                h
            | [] -> assert false
          in
          let new_value =
            match !new_values with
            | h :: t ->
                new_values := t;
                h
            | [] -> assert false
          in
          assert (old_value == called_value);
          new_value
      | _ -> v)
    state
(* Removes duplicate lanes from a vectorized state. Two lanes are duplicates
   when every Vector in the state holds equal values at both positions. The
   surviving lanes are reordered by a lexicographic comparison over the state's
   vectors (in unpack order), and vector_size shrinks accordingly. With no
   vectors at all, every lane is a duplicate and vector_size becomes 1. *)
let dedup_vectorized_state (state : state) : state =
  let old_vector_values = unpack_state_vector_values state in
  let per_vector_compare = List.map vector_compare_indices old_vector_values in
  (* Lexicographic: the first vector that distinguishes the lanes decides. *)
  let rec compare_lanes comparators i j =
    match comparators with
    | [] -> 0
    | cmp :: rest -> (
        match cmp i j with 0 -> compare_lanes rest i j | order -> order)
  in
  let sorted_indices =
    List.init state.vector_size Fun.id
    |> List.sort_uniq (compare_lanes per_vector_compare)
    |> Array.of_list
  in
  let new_values =
    List.map (vector_extract_by_indices sorted_indices) old_vector_values
  in
  {
    (pack_state_vector_values state old_vector_values new_values) with
    vector_size = Array.length sorted_indices;
  }
(* True when every element is structurally equal to the first (vacuously
   true for an empty list). *)
let are_all_list_values_equal (l : 'a list) : bool =
  match l with
  | [] -> true
  | first :: rest -> not (List.exists (fun v -> v <> first) rest)
(* For a single-lane state, turns every value into a Scalar where possible;
   states with more lanes are returned unchanged. *)
let state_unvectorize_if_possible (state : state) : state =
  match state.vector_size with
  | 1 -> state_map_values value_unvectorize_if_possible state
  | _ -> state
(* Merges states of the same shape (shape_of_normalized_state) into single
   vectorized states. The lanes of the merged state are the concatenation of
   the lanes of its members. Duplicate lanes are then removed, and
   single-lane results are unvectorized.

   Fix: a non-vectorizable Vector (a VBool, e.g. the result of a lane-wise
   comparison) that is identical across a group used to be replaced by its
   first lane as a Scalar. That silently discarded the other lanes and made
   the value disagree with the merged vector_size. Such vectors are now
   concatenated lane by lane, like numeric ones. Scalars that are not
   vectorizable must still be equal across the group. *)
let vectorize_states (states : LazyStateSet.t) : LazyStateSet.t =
  (* Combines the values at one position across a group. Each entry is
     (value, vector_size of the state it came from). *)
  let vectorize_values (values : (value * int) list) : value =
    let first_value, first_vector_size = List.hd values in
    let example_value =
      Scalar
        (seq_of_value first_vector_size first_value
        |> Seq.uncons |> Option.get |> fst)
    in
    let is_vector =
      match first_value with Vector _ -> true | Scalar _ -> false
    in
    if can_vectorize_value example_value || is_vector then
      values |> List.to_seq
      |> Seq.concat_map (fun (v, vector_size) -> seq_of_value vector_size v)
      |> value_of_non_empty_seq
    else if are_all_list_values_equal (List.map fst values) then first_value
    else failwith "values are not equal and not vectorizable"
  in
  let vectorize_same_shape_states shape states =
    List.iter state_assert_vector_lengths states;
    {
      (zip_map_state_values vectorize_values shape states) with
      vector_size = List.fold_left (fun a c -> a + c.vector_size) 0 states;
    }
  in
  let states_by_shape =
    Seq.fold_left
      (fun states_by_shape state ->
        let shape = shape_of_normalized_state state in
        let old_list =
          StateMap.find_opt shape states_by_shape |> Option.value ~default:[]
        in
        StateMap.add shape (state :: old_list) states_by_shape)
      StateMap.empty
      (LazyStateSet.to_normalized_non_deduped_seq states)
  in
  states_by_shape |> StateMap.to_seq
  |> Seq.map (fun (shape, states) ->
         let vectorized_state =
           vectorize_same_shape_states shape states
           |> dedup_vectorized_state |> state_unvectorize_if_possible
         in
         state_assert_vector_lengths vectorized_state;
         vectorized_state)
  |> List.of_seq |> LazyStateSet.of_list
(* Splits a vectorized state into one state per lane, each with every Vector
   replaced by that lane's Scalar. A state without vectors must have
   vector_size 1 and is returned as the only element. *)
let unvectorize_state (state : state) : state Seq.t =
  match unpack_state_vector_values state with
  | [] ->
      assert (state.vector_size = 1);
      Seq.return state
  | vector_values ->
      assert (
        List.for_all
          (fun vec -> length_of_vector vec = state.vector_size)
          vector_values);
      let lanes = zip_seq_list (List.map seq_of_vector vector_values) in
      Seq.map
        (fun scalar_values ->
          let as_values = List.map (fun s -> Scalar s) scalar_values in
          pack_state_vector_values state vector_values as_values)
        lanes
(* A function body's control-flow graph together with what is needed to
   interpret it. *)
type prepared_cfg = {
  cfg : Ir.cfg;
  (* Takes a lookup of the data at other flow nodes and a node, and returns the
     node's data. Presumably the dataflow transfer function for interpreting
     this CFG, with None meaning "not reached yet" — confirm where it is built. *)
  analyze :
    (Block_flow.flow_node -> StateAndMaybeReturnSet.t option) ->
    Block_flow.flow_node ->
    StateAndMaybeReturnSet.t option;
  (* NOTE(review): presumably true when the body has no effect; confirm. *)
  is_noop : bool;
  (* Perf counter for this function (presumably time spent in it). *)
  counter_ref : Perf.timed_counter ref;
}

(* Function tables used during interpretation: user-defined functions (with
   their prepared CFGs) and builtins, both keyed by name. *)
type fixed_env = {
  fun_defs : (string, Ir.fun_def * prepared_cfg) Hashtbl.t;
  builtin_funs : (string, builtin_fun) Hashtbl.t;
}

(* An environment with no functions. *)
let empty_fixed_env : fixed_env =
  { fun_defs = Hashtbl.create 0; builtin_funs = Hashtbl.create 0 }
(* Computes the live local variables of [cfg] with ocamlgraph's
   Graph.Fixpoint, as a backward dataflow analysis over
   [Flow.flow_graph_of_cfg cfg]. Transfer functions for instructions and
   terminators come from Liveness. Sets are lifted to options via
   Flow.lift_equal / Flow.lift_join (presumably None = "no information yet";
   confirm in Flow). Every vertex starts at [Some Ir.LocalIdSet.empty].
   Returns the fixpoint as a function from flow-graph vertex to its data. *)
let analyze_live_variables cfg =
  let module LiveVariableAnalysis =
    Graph.Fixpoint.Make
      (Flow.G)
      (struct
        type vertex = Flow.G.E.vertex
        type edge = Flow.G.E.t
        type g = Flow.G.t
        type data = Ir.LocalIdSet.t option

        let direction = Graph.Fixpoint.Backward
        let equal = Flow.lift_equal Ir.LocalIdSet.equal
        let join = Flow.lift_join Ir.LocalIdSet.union

        let analyze =
          Flow.make_flow_function Liveness.flow_instruction_live_variables
            Liveness.flow_terminator_live_variables cfg
      end)
  in
  let g = Flow.flow_graph_of_cfg cfg in
  LiveVariableAnalysis.analyze (fun _ -> Some Ir.LocalIdSet.empty) g
(* Outcome of executing a block terminator for a batch of input states. *)
type terminator_result =
  (* each item corresponds to an input state *)
  (* NOTE(review): the whole list is optional; presumably None when nothing is
     returned — confirm at the construction site. *)
  | Ret of value list option
  (* each item might _not_ correspond to an input state *)
  (* Each item pairs a successor label with the state to continue there. *)
  | Br of (Ir.label * state) list
(* Evaluates a unary operator on one scalar:
   "-" negates a number, "not" negates a (possibly unknown) boolean, and "#"
   gives the length of a string or of an array table (an unknown table counts
   as 0). Any other combination raises [Failure]. *)
let interpret_unary_op_scalar (state : state) (op : string) (v : scalar_value) :
    scalar_value =
  let unsupported () =
    failwith
      (Printf.sprintf "Unsupported unary op: %s %s" op (show_scalar_value v))
  in
  match op with
  | "-" -> (
      match v with SNumber n -> SNumber (Pico_number.neg n) | _ -> unsupported ())
  | "not" -> (
      match v with
      | SBool b -> SBool (not b)
      | SUnknownBool -> SUnknownBool
      | _ -> unsupported ())
  | "#" -> (
      match v with
      | SString s -> SNumber (Pico_number.of_int (String.length s))
      | SPointer heap_id -> (
          match Heap.find heap_id state.heap with
          | HArrayTable items ->
              SNumber (Pico_number.of_int (ListForArrayTable.length items))
          | HUnknownTable -> SNumber (Pico_number.of_int 0)
          | _ -> failwith "Expected HArrayTable or HUnknownTable")
      | _ -> unsupported ())
  | _ -> unsupported ()
(* Applies a unary operator to a value, lane-wise for vectors. *)
let interpret_unary_op state op v =
  let apply = interpret_unary_op_scalar state op in
  match v with
  | Scalar s -> Scalar (apply s)
  | Vector vec -> map_vector apply vec
let rec interpret_binary_op_scalar (l : scalar_value) (op : string)
(r : scalar_value) : scalar_value =
let is_simple_value v =
match v with
| SNumber _ -> true
| SNumberInterval _ -> false
| SBool _ -> true
| SUnknownBool -> false
| SString _ -> true
| SNil _ -> false
| SPointer _ -> false
| SNilPointer _ -> false
in
match (l, op, r) with
| a, "==", b when is_simple_value a && is_simple_value b -> SBool (a = b)
| a, "~=", b when is_simple_value a && is_simple_value b -> SBool (a <> b)
| SNil _, "==", SNil _ -> SBool true
| SNil _, "~=", SNil _ -> SBool false
| a, "==", SNil _ when is_simple_value a -> SBool false
| a, "~=", SNil _ when is_simple_value a -> SBool true
| SNil _, "==", b when is_simple_value b -> SBool false
| SNil _, "~=", b when is_simple_value b -> SBool true
| SPointer l, "==", SPointer r -> SBool (l = r)
| SPointer l, "~=", SPointer r -> SBool (l <> r)
| a, "==", SPointer _ when is_simple_value a -> SBool false
| a, "~=", SPointer _ when is_simple_value a -> SBool true
| SPointer _, "==", b when is_simple_value b -> SBool false
| SPointer _, "~=", b when is_simple_value b -> SBool true
| SNil _, "==", SPointer _ -> SBool false
| SNil _, "~=", SPointer _ -> SBool true
| SPointer _, "==", SNil _ -> SBool false
| SPointer _, "~=", SNil _ -> SBool true
| SNumber l, "+", SNumber r -> SNumber (Pico_number.add l r)
| SNumber l, "-", SNumber r -> SNumber (Pico_number.sub l r)
| SNumber l, "*", SNumber r -> SNumber (Pico_number.mul l r)
| SNumber l, "/", SNumber r -> SNumber (Pico_number.div l r)
| SNumber l, "%", SNumber r -> SNumber (Pico_number.modulo l r)
| SNumber l, "<", SNumber r -> SBool (Int32.compare l r < 0)
| SNumber l, "<=", SNumber r -> SBool (Int32.compare l r <= 0)
| SNumber l, ">", SNumber r -> SBool (Int32.compare l r > 0)
| SNumber l, ">=", SNumber r -> SBool (Int32.compare l r >= 0)
| SNumber l, op, SNumberInterval r ->
let l = SNumberInterval (Pico_number_interval.of_number l) in
let r = SNumberInterval r in
interpret_binary_op_scalar l op r
| SNumberInterval l, op, SNumber r ->
let l = SNumberInterval l in
let r = SNumberInterval (Pico_number_interval.of_number r) in
interpret_binary_op_scalar l op r
| SNumberInterval l, "+", SNumberInterval r ->