Add a builtin that allows specifying which iterate method to use
When using the Cassette mechanism to intercept calls to _apply,
a common strategy is to rewrite the function argument to properly
consider the context and then fall back to regular _apply.
However, as shown in JuliaLabs/Cassette.jl#146,
this strategy is insufficient because _apply itself may recurse into
various `iterate` calls which are not properly tracked. This is an
attempt to resolve this problem with a minimal performance penalty.
Attempting to duplicate the _apply logic in Julia would lead to
code that is very hard for inference (and for nested Cassette passes)
to understand. In contrast, this simply adds a version of _apply that
takes `iterate` as an explicit argument. Cassette and similar tools
can override this argument and provide a function that properly
allows the context to recurse through the iteration, while still
allowing inference to take advantage of the special handling of _apply
for simple cases.

Also change the lowering of splatting to use this new builtin directly,
thus fixing #26001.
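
A minimal sketch of the new builtin's semantics (not part of this diff; `traced_iterate` is a hypothetical stand-in for the function a Cassette-style tool would supply):

# Plain splatting is equivalent to passing Base.iterate explicitly:
Core._apply_iterate(Base.iterate, +, (1, 2), (3, 4))   # same as +(1, 2, 3, 4) == 10

# A context-aware tool can substitute its own iterate so the iteration driven
# by the splat is also routed through the context:
traced_iterate(args...) = (println("iterate intercepted"); Base.iterate(args...))
Core._apply_iterate(traced_iterate, +, 1:3)             # prints, then returns 6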
Keno committed Oct 12, 2019
1 parent 792d3e9 commit a3228b7
Showing 11 changed files with 70 additions and 33 deletions.
1 change: 1 addition & 0 deletions base/boot.jl
@@ -210,6 +210,7 @@ else
const UInt = UInt32
end

function iterate end
function Typeof end
ccall(:jl_toplevel_eval_in, Any, (Any, Any),
Core, quote
27 changes: 18 additions & 9 deletions base/compiler/abstractinterpretation.jl
@@ -427,7 +427,7 @@ end
# refine its type to an array of element types.
# Union of Tuples of the same length is converted to Tuple of Unions.
# returns an array of types
function precise_container_type(@nospecialize(typ), vtypes::VarTable, sv::InferenceState)
function precise_container_type(@nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState)
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
return typ.fields
end
@@ -489,17 +489,24 @@ function precise_container_type(@nospecialize(typ), vtypes::VarTable, sv::Infere
elseif tti0 <: Array
return Any[Vararg{eltype(tti0)}]
else
return abstract_iteration(typ, vtypes, sv)
return abstract_iteration(itft, typ, vtypes, sv)
end
end

# simulate iteration protocol on container type up to fixpoint
function abstract_iteration(@nospecialize(itertype), vtypes::VarTable, sv::InferenceState)
function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState)
if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate)
return Any[Vararg{Any}]
end
iteratef = getfield(Main.Base, :iterate)
stateordonet = abstract_call(iteratef, nothing, Any[Const(iteratef), itertype], vtypes, sv)
if itft === nothing
iteratef = getfield(Main.Base, :iterate)
itft = Const(iteratef)
elseif isa(itft, Const)
iteratef = itft.val
else
return Any[Vararg{Any}]
end
stateordonet = abstract_call(iteratef, nothing, Any[itft, itertype], vtypes, sv)
# Return Bottom if this is not an iterator.
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
@@ -543,7 +550,7 @@ function abstract_iteration(@nospecialize(itertype), vtypes::VarTable, sv::Infer
end

# do apply(af, fargs...), where af is a function value
function abstract_apply(@nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState,
function abstract_apply(@nospecialize(itft), @nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState,
max_methods = sv.params.MAX_METHODS)
aftw = widenconst(aft)
if !isa(aft, Const) && (!isType(aftw) || has_free_typevars(aftw))
@@ -561,7 +568,7 @@ function abstract_apply(@nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarT
for i = 1:nargs
ctypes´ = []
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
cti = precise_container_type(ti, vtypes, sv)
cti = precise_container_type(itft, ti, vtypes, sv)
if _any(t -> t === Bottom, cti)
continue
end
@@ -634,7 +641,9 @@ end

function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS)
if f === _apply
return abstract_apply(argtypes[2], argtypes[3:end], vtypes, sv, max_methods)
return abstract_apply(nothing, argtypes[2], argtypes[3:end], vtypes, sv, max_methods)
elseif f === _apply_iterate
return abstract_apply(argtypes[2], argtypes[3], argtypes[4:end], vtypes, sv, max_methods)
end

la = length(argtypes)
@@ -662,7 +671,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
end
rt = builtin_tfunction(f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple
cti = precise_container_type(argtypes[2], vtypes, sv)
cti = precise_container_type(nothing, argtypes[2], vtypes, sv)
idx = argtypes[3].val
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
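
A quick, hedged way to check the inference threading above (not part of this commit's tests): even though splatting now lowers through _apply_iterate, simple tuple splats should still infer precisely.

g(x, y) = x + y
h(t) = g(t...)
Base.return_types(h, Tuple{Tuple{Int,Int}})  # expected to be Any[Int]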
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
@@ -5,7 +5,7 @@ getfield(getfield(Main, :Core), :eval)(getfield(Main, :Core), :(baremodule Compi
using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstance, CodeInstance
_apply, _apply_iterate, svec, apply_type, Builtin, IntrinsicFunction, MethodInstance, CodeInstance

const getproperty = getfield
const setproperty! = setfield!
21 changes: 11 additions & 10 deletions base/compiler/ssair/inlining.jl
@@ -592,11 +592,11 @@ function spec_lambda(@nospecialize(atype), sv::OptimizationState, @nospecialize(
end

# This assumes the caller has verified that all arguments to the _apply call are Tuples.
function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any})
new_argexprs = Any[argexprs[2]]
new_atypes = Any[atypes[2]]
function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any}, arg_start::Int)
new_argexprs = Any[argexprs[arg_start]]
new_atypes = Any[atypes[arg_start]]
# loop over original arguments and flatten any known iterators
for i in 3:length(argexprs)
for i in (arg_start+1):length(argexprs)
def = argexprs[i]
def_type = atypes[i]
if def_type isa PartialStruct
@@ -882,25 +882,26 @@ end

function inline_apply!(ir::IRCode, idx::Int, sig::Signature, params::Params)
stmt = ir.stmts[idx]
while sig.f === Core._apply
while sig.f === Core._apply || sig.f === Core._apply_iterate
arg_start = sig.f === Core._apply ? 2 : 3
atypes = sig.atypes
# Try to figure out the signature of the function being called
# and if rewrite_apply_exprargs can deal with this form
for i = 3:length(atypes)
for i = (arg_start + 1):length(atypes)
# TODO: We could basically run the iteration protocol here
if !is_valid_type_for_apply_rewrite(atypes[i], params)
return nothing
end
end
# Independent of whether we can inline, the above analysis allows us to rewrite
# this apply call to a regular call
ft = atypes[2]
if length(atypes) == 3 && ft isa Const && ft.val === Core.tuple && atypes[3] ⊑ Tuple
ft = atypes[arg_start]
if length(atypes) == arg_start+1 && ft isa Const && ft.val === Core.tuple && atypes[arg_start+1] ⊑ Tuple
# rewrite `((t::Tuple)...,)` to `t`
ir.stmts[idx] = stmt.args[3]
ir.stmts[idx] = stmt.args[arg_start+1]
return nothing
end
stmt.args, atypes = rewrite_apply_exprargs!(ir, idx, stmt.args, atypes)
stmt.args, atypes = rewrite_apply_exprargs!(ir, idx, stmt.args, atypes, arg_start)
has_free_typevars(ft) && return nothing
f = singleton_type(ft)
sig = Signature(f, ft, atypes)
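
The `((t::Tuple)...,)` collapse handled above can be observed directly; a hedged illustration (exact typed IR varies by Julia version):

f(t::Tuple) = (t...,)
# @code_typed f((1, 2)) is expected to show the argument returned as-is,
# with no remaining _apply/_apply_iterate call.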
2 changes: 2 additions & 0 deletions base/essentials.jl
@@ -113,6 +113,8 @@ macro _propagate_inbounds_meta()
return Expr(:meta, :inline, :propagate_inbounds)
end

function iterate end

"""
convert(T, x)
2 changes: 1 addition & 1 deletion src/builtin_proto.h
@@ -23,7 +23,7 @@ DECLARE_BUILTIN(throw); DECLARE_BUILTIN(is);
DECLARE_BUILTIN(typeof); DECLARE_BUILTIN(sizeof);
DECLARE_BUILTIN(issubtype); DECLARE_BUILTIN(isa);
DECLARE_BUILTIN(_apply); DECLARE_BUILTIN(_apply_pure);
DECLARE_BUILTIN(_apply_latest);
DECLARE_BUILTIN(_apply_latest); DECLARE_BUILTIN(_apply_iterate);
DECLARE_BUILTIN(isdefined); DECLARE_BUILTIN(nfields);
DECLARE_BUILTIN(tuple); DECLARE_BUILTIN(svec);
DECLARE_BUILTIN(getfield); DECLARE_BUILTIN(setfield);
28 changes: 21 additions & 7 deletions src/builtins.c
@@ -474,7 +474,7 @@ void STATIC_INLINE _grow_to(jl_value_t **root, jl_value_t ***oldargs, jl_svec_t

static jl_function_t *jl_iterate_func JL_GLOBALLY_ROOTED;

JL_CALLABLE(jl_f__apply)
static jl_value_t *do_apply(jl_value_t *F, jl_value_t **args, uint32_t nargs, jl_value_t *iterate)
{
JL_NARGSV(apply, 1);
jl_function_t *f = args[0];
@@ -516,10 +516,13 @@ JL_CALLABLE(jl_f__apply)
extra += 1;
}
}
if (extra && jl_iterate_func == NULL) {
jl_iterate_func = jl_get_function(jl_top_module, "iterate");
if (jl_iterate_func == NULL)
jl_undefined_var_error(jl_symbol("iterate"));
if (extra && iterate == NULL) {
if (jl_iterate_func == NULL) {
jl_iterate_func = jl_get_function(jl_top_module, "iterate");
if (jl_iterate_func == NULL)
jl_undefined_var_error(jl_symbol("iterate"));
}
iterate = jl_iterate_func;
}
// allocate space for the argument array and gc roots for it
// based on our previous estimates
@@ -599,7 +602,7 @@ JL_CALLABLE(jl_f__apply)
assert(extra > 0);
jl_value_t *args[2];
args[0] = ai;
jl_value_t *next = jl_apply_generic(jl_iterate_func, args, 1);
jl_value_t *next = jl_apply_generic(iterate, args, 1);
while (next != jl_nothing) {
roots[stackalloc] = next;
jl_value_t *value = jl_get_nth_field_checked(next, 0);
@@ -614,7 +617,7 @@ JL_CALLABLE(jl_f__apply)
roots[stackalloc + 1] = NULL;
JL_GC_ASSERT_LIVE(state);
args[1] = state;
next = jl_apply_generic(jl_iterate_func, args, 2);
next = jl_apply_generic(iterate, args, 2);
}
roots[stackalloc] = NULL;
extra -= 1;
@@ -629,6 +632,16 @@ JL_CALLABLE(jl_f__apply)
return result;
}

JL_CALLABLE(jl_f__apply_iterate)
{
return do_apply(F, args+1, nargs-1, args[0]);
}

JL_CALLABLE(jl_f__apply)
{
return do_apply(F, args, nargs, NULL);
}

// this is like `_apply`, but with quasi-exact checks to make sure it is pure
JL_CALLABLE(jl_f__apply_pure)
{
@@ -1301,6 +1314,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
// internal functions
jl_builtin_apply_type = add_builtin_func("apply_type", jl_f_apply_type);
jl_builtin__apply = add_builtin_func("_apply", jl_f__apply);
jl_builtin__apply_iterate = add_builtin_func("_apply_iterate", jl_f__apply_iterate);
jl_builtin__expr = add_builtin_func("_expr", jl_f__expr);
jl_builtin_svec = add_builtin_func("svec", jl_f_svec);
add_builtin_func("_apply_pure", jl_f__apply_pure);
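
A hedged sketch of the runtime dispatch above: Tuple arguments are expanded directly and never touch the supplied iterate, while a generic iterable is expanded by calling it (`counting_iterate` is hypothetical):

const calls = Ref(0)
counting_iterate(args...) = (calls[] += 1; Base.iterate(args...))

Core._apply_iterate(counting_iterate, tuple, (1, 2))            # tuple fast path
@assert calls[] == 0
Core._apply_iterate(counting_iterate, tuple, (x for x in 1:2))  # generic path
@assert calls[] > 0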
9 changes: 6 additions & 3 deletions src/codegen.cpp
@@ -2517,11 +2517,13 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
}
}

else if (f == jl_builtin__apply && nargs == 2 && ctx.vaSlot > 0) {
else if (((f == jl_builtin__apply && nargs == 2) ||
(f == jl_builtin__apply_iterate && nargs == 3)) && ctx.vaSlot > 0) {
int arg_start = f == jl_builtin__apply ? 2 : 3;
// turn Core._apply(f, Tuple) ==> f(Tuple...) using the jlcall calling convention if Tuple is the va allocation
if (LoadInst *load = dyn_cast_or_null<LoadInst>(argv[2].V)) {
if (LoadInst *load = dyn_cast_or_null<LoadInst>(argv[arg_start].V)) {
if (load->getPointerOperand() == ctx.slots[ctx.vaSlot].boxroot && ctx.argArray) {
Value *theF = maybe_decay_untracked(boxed(ctx, argv[1]));
Value *theF = maybe_decay_untracked(boxed(ctx, argv[arg_start-1]));
Value *nva = emit_n_varargs(ctx);
#ifdef _P64
nva = ctx.builder.CreateTrunc(nva, T_int32);
@@ -7278,6 +7280,7 @@ static void init_julia_llvm_env(Module *m)
builtin_func_map[jl_f_typeassert] = jlcall_func_to_llvm("jl_f_typeassert", &jl_f_typeassert, m);
builtin_func_map[jl_f_ifelse] = jlcall_func_to_llvm("jl_f_ifelse", &jl_f_ifelse, m);
builtin_func_map[jl_f__apply] = jlcall_func_to_llvm("jl_f__apply", &jl_f__apply, m);
builtin_func_map[jl_f__apply_iterate] = jlcall_func_to_llvm("jl_f__apply_iterate", &jl_f__apply_iterate, m);
builtin_func_map[jl_f__apply_pure] = jlcall_func_to_llvm("jl_f__apply_pure", &jl_f__apply_pure, m);
builtin_func_map[jl_f__apply_latest] = jlcall_func_to_llvm("jl_f__apply_latest", &jl_f__apply_latest, m);
builtin_func_map[jl_f_throw] = jlcall_func_to_llvm("jl_f_throw", &jl_f_throw, m);
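
A hedged example of the pattern this codegen branch targets, the generic forwarding function `f(args...) = g(args...)` (names hypothetical); the test added below checks the same property for Base.vect:

inner(args...) = length(args)
outer(args...) = inner(args...)
# code_llvm(outer, Tuple{Vararg{Int}}) is expected to contain no "__apply" call.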
2 changes: 1 addition & 1 deletion src/julia-syntax.scm
@@ -2096,7 +2096,7 @@
(tuple-wrap (cdr a) '())))
(tuple-wrap (cdr a) (cons x run))))))
(expand-forms
`(call (core _apply) ,f ,@(tuple-wrap argl '())))))
`(call (core _apply_iterate) (top iterate) ,f ,@(tuple-wrap argl '())))))

((and (eq? (identifier-name f) '^) (length= e 4) (integer? (cadddr e)))
(expand-forms
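
Roughly, a splat such as `f(a, bs...)` now lowers to an explicit _apply_iterate call (a hedged illustration; inspect with Meta.lower):

# Meta.lower(Main, :(f(a, bs...))) is expected to produce, roughly:
#   Core._apply_iterate(Base.iterate, f, Core.tuple(a), bs)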
3 changes: 2 additions & 1 deletion src/staticdata.c
@@ -76,6 +76,7 @@ static void *const _tags[] = {
// some Core.Builtin Functions that we want to be able to reference:
&jl_builtin_throw, &jl_builtin_is, &jl_builtin_typeof, &jl_builtin_sizeof,
&jl_builtin_issubtype, &jl_builtin_isa, &jl_builtin_typeassert, &jl_builtin__apply,
&jl_builtin__apply_iterate,
&jl_builtin_isdefined, &jl_builtin_nfields, &jl_builtin_tuple, &jl_builtin_svec,
&jl_builtin_getfield, &jl_builtin_setfield, &jl_builtin_fieldtype, &jl_builtin_arrayref,
&jl_builtin_const_arrayref, &jl_builtin_arrayset, &jl_builtin_arraysize,
@@ -109,7 +110,7 @@ static htable_t fptr_to_id;
// This is a manually constructed dual of the fvars array, which would be produced by codegen for Julia code, for C.
static const jl_fptr_args_t id_to_fptrs[] = {
&jl_f_throw, &jl_f_is, &jl_f_typeof, &jl_f_issubtype, &jl_f_isa,
&jl_f_typeassert, &jl_f__apply, &jl_f__apply_pure, &jl_f__apply_latest, &jl_f_isdefined,
&jl_f_typeassert, &jl_f__apply, &jl_f__apply_iterate, &jl_f__apply_pure, &jl_f__apply_latest, &jl_f_isdefined,
&jl_f_tuple, &jl_f_svec, &jl_f_intrinsic_call, &jl_f_invoke_kwsorter,
&jl_f_getfield, &jl_f_setfield, &jl_f_fieldtype, &jl_f_nfields,
&jl_f_arrayref, &jl_f_const_arrayref, &jl_f_arrayset, &jl_f_arraysize, &jl_f_apply_type,
6 changes: 6 additions & 0 deletions test/compiler/codegen.jl
@@ -391,3 +391,9 @@ end
# Warm up
f_dict_hash_alloc(); g_dict_hash_alloc();
@test (@allocated f_dict_hash_alloc()) == (@allocated g_dict_hash_alloc())

let io = IOBuffer()
# Test for the f(args...) = g(args...) generic codegen optimization
code_llvm(io, Base.vect, Tuple{Vararg{Union{Float64, Int64}}})
@test !occursin("__apply", String(take!(io)))
end
