diff --git a/lib/common/Variable.ml b/lib/common/Variable.ml index 628e719..575db06 100644 --- a/lib/common/Variable.ml +++ b/lib/common/Variable.ml @@ -65,6 +65,20 @@ let are_flags_included v1 v2 = Flags.subset v1.flags v2.flags let merge_flags v1 v2 gen = gen (Flags.union v1.flags v2.flags) +let rec find_most_general_rec flags v = function + | h :: t -> + let flags_h = get_flags h in + let new_flags = Flags.union flags flags_h in + if Flags.equal new_flags flags_h then find_most_general_rec new_flags (Some h) t + else if Flags.equal new_flags flags then find_most_general_rec new_flags v t + else find_most_general_rec new_flags None t + | [] -> + match v with + | Some v -> Either.Left v + | None -> Either.Right flags + +let find_most_general = find_most_general_rec Flags.empty None + module Map = CCMap.Make (struct type nonrec t = t let compare = compare diff --git a/lib/common/Variable.mli b/lib/common/Variable.mli index b3b3c45..d1e6e0a 100644 --- a/lib/common/Variable.mli +++ b/lib/common/Variable.mli @@ -34,6 +34,8 @@ val are_flags_included : t -> t -> bool *) val merge_flags : t -> t -> ( Flags.t -> t) -> t +val find_most_general : t list -> (t, Flags.t) Either.t + module Map : CCMap.S with type key = t module HMap : CCHashtbl.S with type key = t module Set : CCSet.S with type elt = t diff --git a/lib/unification/Syntactic.ml b/lib/unification/Syntactic.ml index 08e2299..bbc9d15 100644 --- a/lib/unification/Syntactic.ml +++ b/lib/unification/Syntactic.ml @@ -129,16 +129,20 @@ let rec occur_check env : return = (Type.tuple (Env.tyenv env) Type.NSet.empty)(*unit*)) l stack in - collapse stack tl + collapse stack (v::tl) ) | _ -> Some stack in match collapse Stack.empty l with | None -> FailedOccurCheck env | Some stack -> - let v = List.hd l in - List.iter (fun u -> Env.add env u (Type.var (Env.tyenv env) v)) (List.tl l); - Env.remove env v; + (match Variable.find_most_general l with + | Either.Right flags -> + let v = Env.gen flags env in + List.iter (fun u -> Env.add env u (Type.var (Env.tyenv env) v)) l + | Left v -> + List.iter (fun u -> Env.add env u (Type.var (Env.tyenv env) v)) l; + Env.remove env v); let* () = process_stack env stack in occur_check env diff --git a/test/unit_tests/test_unification.ml b/test/unit_tests/test_unification.ml index 41d8357..0dc45f4 100644 --- a/test/unit_tests/test_unification.ml +++ b/test/unit_tests/test_unification.ml @@ -70,6 +70,8 @@ let pos_tests = [ "'a -> 'b", "'a * 'b" ; (* Bug non-arrow *) "((float -> int) * float * int, int -> 'b) t", "('b * int, int -> 'b) t"; + (* Bug occur check collapse *) + "'a * 'a -> 'a", "('c * 'd * 'e * 'f * (('c * 'e -> 'a)) * (('d * 'f -> 'b)) -> ('a * 'b))" ] let neg_tests = [