# coding=utf-8
# Copyright 2022 The Fiddle-Config Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides utilities for transforming builder functions into `fdl.Config`s.
This module defines the `auto_config` function (and associated helpers), which
can be used to convert an existing function that creates an object graph into a
function that creates a graph of `Config` and `Partial` objects. When the
resulting graph of `Config` and `Partial` objects is built via `fdl.build()`, it
will yield the same object graph as the original function.
"""
import ast
import builtins
import dataclasses
import functools
import inspect
import linecache
import textwrap
import types
from typing import Any, Callable, Optional, Type, TypeVar, cast
from fiddle._src import arg_factory
from fiddle._src import building
from fiddle._src import casting as cast_lib
from fiddle._src import config
from fiddle._src import copying
from fiddle._src import mutate_buildable
from fiddle._src import partial
from fiddle._src.experimental import auto_config_policy
from fiddle._src.experimental import daglish_legacy
import libcst as cst
# Identifiers that the AST rewrite substitutes for calls, attribute loads and
# attribute stores, respectively.
_CALL_HANDLER_ID = '__auto_config_call_handler__'
_ATTR_LOAD_HANDLER_ID = '__auto_config_attr_load_handler__'
_ATTR_SAVE_HANDLER_ID = '__auto_config_attr_save_handler__'
# Prefix of the temporaries introduced when rewriting attribute assignments.
_ATTR_SAVE_TEMP_VAR_ID = '_attr_save_temp'
# Name of the synthetic function wrapped around a rewritten function body.
_CLOSURE_WRAPPER_ID = '__auto_config_closure_wrapper__'

# Argument list of a function that takes no parameters at all.
_EMPTY_ARGUMENTS = ast.arguments(
    posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
)

# Every routine or class found in the `builtins` module namespace.
_BUILTINS = frozenset(
    member
    for member in vars(builtins).values()
    if inspect.isclass(member) or inspect.isroutine(member)
)

_GenericCallable = TypeVar('_GenericCallable', bound=Callable[..., Any])
@dataclasses.dataclass(frozen=True)
class AutoConfigClassMethod:
  """Descriptor wrapping an auto_config'd class method.

  Looking the attribute up through this descriptor produces an `AutoConfig`
  whose `func` and `buildable_func` are both `self.func` bound to the type the
  lookup went through.
  """

  # The wrapped function.
  func: Callable[..., Any]
  # Forwarded unchanged to every `AutoConfig` created by `__get__`.
  always_inline: bool

  def __get__(self, obj, objtype=None):
    """Returns an `AutoConfig` with `func` bound to `objtype`.

    Args:
      obj: The instance the lookup went through; not used.
      objtype: The type the lookup went through; becomes the bound receiver.
    """
    # Two separate bound-method objects are created, one per field.
    bind_to_owner = lambda: types.MethodType(self.func, objtype)
    return AutoConfig(
        func=bind_to_owner(),
        buildable_func=bind_to_owner(),
        always_inline=self.always_inline,
    )

  @property
  def __wrapped__(self):
    """The undecorated function, mirroring `functools.wraps`."""
    return self.func
[docs]
@dataclasses.dataclass(frozen=True)
class AutoConfig:
  """Callable wrapper returned for auto_config'd functions.

  Calling the wrapper runs `func`; `as_buildable` runs `buildable_func`. The
  wrapper copies the usual `functools.wraps` metadata from `func`, forwards
  unknown attribute lookups to `func`, and implements the descriptor protocol
  so that both callables are re-bound together when accessed as a method.
  """

  # Invoked by `__call__`.
  func: Callable[..., Any]
  # Invoked by `as_buildable`.
  buildable_func: Callable[..., config.Buildable]
  # Consulted by the auto_config call handler when this wrapper is called
  # from inside another auto_config'd function.
  always_inline: bool

  @property
  def nowrap(self):
    return True  # Tells Flax not to decorate this object, for classmethods.

  def __post_init__(self):
    """Copies `functools.wraps`-style metadata from `func` onto `self`."""
    absent = object()
    wrapped_metadata = (
        '__module__',
        '__name__',
        '__qualname__',
        '__doc__',
        '__annotations__',
    )
    for attr_name in wrapped_metadata:
      attr_value = getattr(self.func, attr_name, absent)
      if attr_value is not absent:
        # The dataclass is frozen, so ordinary assignment is not available.
        object.__setattr__(self, attr_name, attr_value)

  def __call__(self, /, *args, **kwargs) -> Any:
    """Calls the wrapped function with the given arguments."""
    return self.func(*args, **kwargs)

  def as_buildable(self, /, *args, **kwargs) -> config.Buildable:
    """Calls `buildable_func` with the given arguments."""
    return self.buildable_func(*args, **kwargs)

  def __get__(self, obj, objtype=None):
    """Returns a new `AutoConfig` with both callables bound via `__get__`."""
    # pytype: disable=attribute-error
    bound_func = self.func.__get__(obj, objtype)
    bound_buildable_func = self.buildable_func.__get__(obj, objtype)
    # pytype: enable=attribute-error
    return AutoConfig(
        func=bound_func,
        buildable_func=bound_buildable_func,
        always_inline=self.always_inline,
    )

  @property
  def __wrapped__(self):
    """The wrapped function, mirroring `functools.wraps`."""
    return self.func

  def __getattr__(self, name):
    # Only reached when normal lookup fails: defer to the wrapped function.
    # `func` is fetched through super().__getattribute__ so that a missing
    # `func` cannot re-enter this method and recurse forever.
    wrapped = super().__getattribute__('func')
    return getattr(wrapped, name)
class UnsupportedLanguageConstructError(SyntaxError):
  """Raised when auto_config meets a language construct it does not accept."""
class _AutoConfigNodeTransformer(ast.NodeTransformer):
  """A NodeTransformer that adds the auto-config handlers into an AST.

  Calls, attribute loads and attribute assignments are rewritten into calls to
  the names held in `_CALL_HANDLER_ID`, `_ATTR_LOAD_HANDLER_ID` and
  `_ATTR_SAVE_HANDLER_ID`. Constructs that are not accepted raise
  `UnsupportedLanguageConstructError`.
  """

  def __init__(
      self,
      source: str,
      filename: str,
      line_number: int,
      allow_control_flow=False,
  ):
    """Initializes the auto config node transformer instance.

    Args:
      source: The source code of the node that will be transformed by this
        instance. This is used for better error reporting.
      filename: The filename `source` is from.
      line_number: The line number of `source` within `filename`.
      allow_control_flow: Whether to permit control flow constructs (loops,
        conditionals, comprehensions, etc). By default, this is `False`, and
        control flow constructs will cause an
        `UnsupportedLanguageConstructError` to be raised.
    """
    self._lines = source.splitlines()
    self._filename = filename
    self._line_number = line_number
    self._allow_control_flow = allow_control_flow
    # How many function/lambda bodies are currently being visited.
    self._function_def_depth = 0
    # How many temporaries have been created; keeps their names unique.
    self._temp_var_count = 0

  def _location_for(self, node: ast.AST):
    """Returns `(filename, line_number, column, source_line)` for `node`.

    `node.lineno` is relative to `source`, so it is shifted by the line at
    which `source` starts within `filename`.
    """
    line_number = self._line_number + node.lineno - 1  # pytype: disable=attribute-error
    line = self._lines[node.lineno - 1]  # pytype: disable=attribute-error
    return (self._filename, line_number, node.col_offset, line)  # pytype: disable=attribute-error

  def _handle_control_flow(self, node: ast.AST, activatable: bool = False):
    """Visits `node` if control flow is enabled for it, otherwise raises.

    Args:
      node: The control flow node being visited.
      activatable: Whether `allow_control_flow=True` is able to enable this
        kind of node. Nodes visited with `activatable=False` always raise.

    Returns:
      The result of `generic_visit(node)` when the construct is permitted.

    Raises:
      UnsupportedLanguageConstructError: If the construct is not permitted.
    """
    if self._allow_control_flow and activatable:
      return self.generic_visit(node)
    msg = f'Control flow ({type(node).__name__}) is unsupported by auto_config.'
    raise UnsupportedLanguageConstructError(msg, self._location_for(node))

  def _generic_visit_inside_function(self, node):
    """Runs `generic_visit(node)` with the function depth raised by one."""
    self._function_def_depth += 1
    try:
      return self.generic_visit(node)
    finally:
      self._function_def_depth -= 1

  def _validate_decorator_ordering(self, node: ast.FunctionDef):
    """Validates that decorators are applied in the right order.

    This is done on a best effort basis to catch cases where @classmethod or
    @staticmethod are applied on top of @auto_config.

    Args:
      node: The `ast.FunctionDef` node to validate decorators for.

    Raises:
      AssertionError: If `classmethod` or `staticmethod` is listed above
        `auto_config`.
    """
    decorator_names = []
    for decorator in node.decorator_list:
      # Reduce `@pkg.name(...)`, `@pkg.name` and `@name` to the bare `name`.
      if isinstance(decorator, ast.Call):
        decorator = decorator.func
      if isinstance(decorator, ast.Attribute):
        decorator = decorator.attr
      if isinstance(decorator, ast.Name):
        decorator = decorator.id
      decorator_names.append(decorator)

    try:
      auto_config_index = decorator_names.index('auto_config')
    except ValueError:
      # Probably auto_config wasn't called as a decorator. Another alternative
      # is that the auto_config function was assigned to a variable with a
      # different name before being applied as a decorator...
      return

    for decorator in decorator_names[:auto_config_index]:
      if decorator in ('classmethod', 'staticmethod'):
        raise AssertionError(
            f'@{decorator} placed above @auto_config on function {node.name} '
            f'at {self._filename}:{self._line_number}. Reorder decorators so '
            f'that @auto_config is placed above @{decorator}.'
        )

  # pylint: disable=invalid-name

  def visit_Call(self, node: ast.Call):
    """Rewrites `f(*a, **k)` into `<call handler>(f, *a, **k)`.

    The callee expression is passed through as-is; positional arguments and
    keywords are visited.
    """
    return ast.Call(
        func=ast.Name(id=_CALL_HANDLER_ID, ctx=ast.Load()),
        args=[node.func, *(self.visit(arg) for arg in node.args)],
        keywords=[self.visit(keyword) for keyword in node.keywords],
    )

  def visit_Attribute(self, node: ast.Attribute):
    """Rewrites a loaded `value.attr` into `<load handler>(value, 'attr')`."""
    if isinstance(node.ctx, ast.Load):
      return ast.Call(
          func=ast.Name(id=_ATTR_LOAD_HANDLER_ID, ctx=ast.Load()),
          args=[self.visit(node.value), ast.Constant(value=node.attr)],
          keywords=[],
      )
    return self.generic_visit(node)

  def visit_Assign(self, node: ast.Assign):
    """Rewrites attribute assignment targets into save-handler calls.

    Args:
      node: The assignment statement to transform.

    Returns:
      A single `ast.Expr` when the statement has exactly one target and that
      target is an attribute. Otherwise a list holding the (possibly modified)
      assignment followed by one save-handler call per attribute target.

    Raises:
      NotImplementedError: If a target has a node type that is not handled.
    """

    def make_expr_call(obj, attr, value):
      if isinstance(obj, ast.Attribute):
        obj = self.visit_Attribute(obj)
      return ast.Expr(
          ast.Call(
              func=ast.Name(id=_ATTR_SAVE_HANDLER_ID, ctx=ast.Load()),
              args=[obj, ast.Constant(value=attr), value],  # pytype: disable=missing-parameter
              keywords=[],
          )
      )

    node.value = self.visit(node.value)

    # Avoid creating temp var for single target ast.Assign expression,
    # like `a.b = c.d.e`, to improve simplicity.
    if len(node.targets) == 1 and isinstance(node.targets[0], ast.Attribute):
      only_target = node.targets[0]
      return make_expr_call(only_target.value, only_target.attr, node.value)

    # For multiple targets ast.Assign expression, temporary variables will be
    # created to facilitate set attribute validation.
    # For example, `a.b = c.d = foo` will be transformed into:
    # ```
    # _attr_save_temp_0 = _attr_save_temp_1 = foo
    # __auto_config_attr_save_handler__(a, 'b', _attr_save_temp_0)
    # __auto_config_attr_save_handler__(c, 'd', _attr_save_temp_1)
    # ```
    def divert_through_temp_var(attribute):
      """Returns `(store_name, save_call)` routing `attribute` via a temp."""
      temp_id = f'{_ATTR_SAVE_TEMP_VAR_ID}_{self._temp_var_count}'
      self._temp_var_count += 1
      store_name = ast.Name(id=temp_id, ctx=ast.Store())
      # The handler call *reads* the temporary, so it needs its own node with
      # Load context; reusing the Store node there makes compile() reject the
      # tree.
      load_name = ast.Name(id=temp_id, ctx=ast.Load())
      save_call = make_expr_call(attribute.value, attribute.attr, load_name)
      return store_name, save_call

    save_calls = []
    for index, target in enumerate(node.targets):
      if isinstance(target, (ast.Tuple, ast.List)):
        new_elts = []
        for elt in target.elts:
          if isinstance(elt, ast.Attribute):
            store_name, save_call = divert_through_temp_var(elt)
            new_elts.append(store_name)
            save_calls.append(save_call)
          else:
            new_elts.append(elt)
        target.elts = new_elts
      elif isinstance(target, ast.Attribute):
        store_name, save_call = divert_through_temp_var(target)
        node.targets[index] = store_name
        save_calls.append(save_call)
      elif isinstance(target, ast.Subscript):
        target.value = self.visit(target.value)
        target.slice = self.visit(target.slice)
      elif isinstance(target, ast.Starred):
        # TODO(b/288479702): Add validation when target is ast.Starred.
        pass
      elif isinstance(target, ast.Name):
        pass
      else:
        raise NotImplementedError(
            f'Cannot handle Assign statement with {target} as target.'
        )

    # The assignment must run first so the temporaries exist when the
    # save-handler calls read them.
    return [node, *save_calls]

  def visit_For(self, node: ast.For):
    return self._handle_control_flow(node, activatable=True)

  def visit_While(self, node: ast.While):
    return self._handle_control_flow(node, activatable=True)

  def visit_If(self, node: ast.If):
    return self._handle_control_flow(node, activatable=True)

  def visit_IfExp(self, node: ast.IfExp):
    return self._handle_control_flow(node, activatable=True)

  def visit_ListComp(self, node: ast.ListComp):
    return self._handle_control_flow(node, activatable=True)

  def visit_SetComp(self, node: ast.SetComp):
    return self._handle_control_flow(node, activatable=True)

  def visit_DictComp(self, node: ast.DictComp):
    return self._handle_control_flow(node, activatable=True)

  def visit_GeneratorExp(self, node: ast.GeneratorExp):
    return self._handle_control_flow(node, activatable=True)

  def visit_Try(self, node: ast.Try):
    # Rejected even when control flow is allowed.
    return self._handle_control_flow(node)

  def visit_Raise(self, node: ast.Raise):
    return self._handle_control_flow(node, activatable=True)

  def visit_With(self, node: ast.With):
    # Rejected even when control flow is allowed.
    return self._handle_control_flow(node)

  def visit_Yield(self, node: ast.Yield):
    # Rejected even when control flow is allowed.
    return self._handle_control_flow(node)

  def visit_YieldFrom(self, node: ast.YieldFrom):
    # Rejected even when control flow is allowed.
    return self._handle_control_flow(node)

  def visit_FunctionDef(self, node: ast.FunctionDef):
    """Transforms a FunctionDef node.

    Only the outermost function definition is accepted. Its decorators are
    left untransformed.

    Raises:
      UnsupportedLanguageConstructError: If the definition is nested inside
        another function or lambda.
    """
    if self._function_def_depth > 0:
      msg = 'Nested function definitions are not supported by auto_config.'
      raise UnsupportedLanguageConstructError(msg, self._location_for(node))

    self._validate_decorator_ordering(node)
    # Backup decorator_list because we don't want to transform anything
    # in decorators.
    decorator_list = node.decorator_list
    node.decorator_list = []
    node = self._generic_visit_inside_function(node)
    node.decorator_list = decorator_list
    return node

  def visit_Lambda(self, node: ast.Lambda):
    """Transforms a top-level lambda; rejects lambdas nested in a function."""
    if self._function_def_depth > 0:
      msg = 'Lambda definitions are not supported by auto_config.'
      raise UnsupportedLanguageConstructError(msg, self._location_for(node))
    return self._generic_visit_inside_function(node)

  def visit_ClassDef(self, node: ast.ClassDef):
    msg = 'Class definitions are not supported by auto_config.'
    raise UnsupportedLanguageConstructError(msg, self._location_for(node))

  def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
    msg = 'Async function definitions are not supported by auto_config.'
    raise UnsupportedLanguageConstructError(msg, self._location_for(node))

  # pylint: enable=invalid-name
def _contains_buildable(structure):
  """Returns `True` if a `config.Buildable` is found while walking `structure`."""
  found = False

  def visit(unused_path, value):
    nonlocal found
    if not isinstance(value, config.Buildable):
      yield  # Continue traversal.
      return
    # Finishing without yielding stops the traversal here.
    found = True

  daglish_legacy.traverse_with_path(visit, structure)
  return found
def _wrap_ast_for_fn_with_closure_vars(
    module: ast.Module,
    fn: types.FunctionType,
) -> ast.Module:
  """Wraps `module.body` in a function that defines closure variables for `fn`.

  If `fn` has any free variables (i.e., its `__code__.co_freevars` is not
  empty), compiling its AST on its own would not create the same free
  variables in the resulting code object, because the AST is compiled without
  the context it was originally defined in.

  To work around this, `module.body` is nested inside another `FunctionDef`
  that first assigns `None` to a dummy variable for each of `fn`'s free
  variables. Dummy variables named by `_CALL_HANDLER_ID`,
  `_ATTR_LOAD_HANDLER_ID` and `_ATTR_SAVE_HANDLER_ID` are added after them, in
  that order.

  Effectively, this wrapping looks like the following Python code:

      def __auto_config_closure_wrapper__():
        closure_var_1 = None
        closure_var_2 = None
        ...
        <_CALL_HANDLER_ID> = None
        <_ATTR_LOAD_HANDLER_ID> = None
        <_ATTR_SAVE_HANDLER_ID> = None

        def fn(...):  # Or some expression involving a lambda.
          ...  # Contains references to the closure variables.

  Args:
    module: An `ast.Module` object whose body contains the function definition
      for `fn` (e.g., as an `ast.FunctionDef` or `ast.Lambda`).
    fn: The function to create dummy closure variables for (assumed to
      correspond to the body of `module`).

  Returns:
    A new `ast.Module` containing an additional wrapper `ast.FunctionDef` that
    defines dummy closure variables.
  """
  placeholder_names = fn.__code__.co_freevars + (
      _CALL_HANDLER_ID,
      _ATTR_LOAD_HANDLER_ID,
      _ATTR_SAVE_HANDLER_ID,
  )
  none_literal = ast.Constant(value=None)

  wrapper_body = []
  for placeholder in placeholder_names:
    store_target = ast.Name(id=placeholder, ctx=ast.Store())
    wrapper_body.append(ast.Assign(targets=[store_target], value=none_literal))
  wrapper_body.extend(module.body)

  wrapper_fn = ast.FunctionDef(  # pytype: disable=missing-parameter
      name=_CLOSURE_WRAPPER_ID,
      args=_EMPTY_ARGUMENTS,
      body=wrapper_body,
      decorator_list=[],
  )
  wrapped = ast.Module(body=[wrapper_fn], type_ignores=[])
  return ast.fix_missing_locations(wrapped)
def _find_function_code(code: types.CodeType, fn_name: str):
  """Returns the one code object named `fn_name` among `code.co_consts`."""
  matches = []
  for const in code.co_consts:
    if inspect.iscode(const) and const.co_name == fn_name:
      matches.append(const)
  assert len(matches) == 1, f"Couldn't find function code for {fn_name!r}."
  (match,) = matches
  return match
def _unwrap_code_for_fn(code: types.CodeType, fn: types.FunctionType):
  """Extracts the code object for `fn` from wrapped, compiled `code`.

  This function assumes `code` is the result of compiling an `ast.Module`
  returned by `_wrap_ast_for_fn_with_closure_vars`.

  Args:
    code: A code object containing code for `fn`.
    fn: The function to find a code object for within `code`.

  Returns:
    The code object corresponding to `fn`.
  """
  # Descend one level at a time: module -> closure wrapper -> `fn`.
  wrapper_code = _find_function_code(code, _CLOSURE_WRAPPER_ID)
  return _find_function_code(wrapper_code, fn.__name__)
def _make_closure_cell(contents):
  """Returns a new closure cell whose `cell_contents` is `contents`."""
  # No fallback is needed: this module already uses positional-only
  # parameters (`/`), which require Python 3.8, the same release that
  # introduced `types.CellType`.
  return types.CellType(contents)  # pytype: disable=wrong-arg-count
def _maybe_as_arg_factory(arg_factory_cls, arg):
  """Converts an argument of an arg_factory partial() to a Fiddle buildable.

  In normal Python, one expresses arg factories like,

    my_fn = arg_factory.partial(fn, foo=foo_factory, bar=bar_factory)

  where `foo_factory` produces a `foo` and `bar_factory` produces a `bar`.
  These are called each time `my_fn` is called.

  The Fiddle configuration for `my_fn`, on the other hand, looks like,

    my_fn_config = fdl.Partial(fn, foo=fdl.ArgFactory(foo_factory),
                               bar=fdl.ArgFactory(bar_factory))

  So each factory has to end up as an ArgFactory: a `partial.Partial` is cast
  to `arg_factory_cls`, and any other callable is wrapped in `arg_factory_cls`.

  Anything else is rejected. It's technically possible to pass `foo_factory`
  as a fdl.Config object that, when called, returns another function, but this
  is most likely a mistake in configuration, so we don't allow it.

  Args:
    arg_factory_cls: The type to use when creating the ArgFactory (normally
      this will just be `fdl.ArgFactory`, but can potentially be customized).
    arg: Intermediate value passed to `arg_factory.partial`.

  Returns:
    ArgFactory version of a configuration or callable.

  Raises:
    ValueError: If `arg` is neither a `partial.Partial` nor callable.
  """
  if isinstance(arg, partial.Partial):
    return cast_lib.cast(arg_factory_cls, arg)
  if callable(arg):
    return arg_factory_cls(arg)
  raise ValueError(
      "Couldn't figure out how to handle arg_factory argument; please "
      f'bind any constant args with a nested functools.partial. Arg: {arg!r}'
  )
def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
  """Makes a fdl.Partial, updating an existing one instead of nesting it.

  Args:
    partial_cls: The type to use when creating the Partial (normally this will
      just be `fdl.Partial`, but can potentially be customized).
    buildable_or_callable: Callable or existing configuration object to update.
    *args: Positional arguments; rejected when `buildable_or_callable` is
      already a Partial.
    **kwargs: Keyword arguments.

  Returns:
    `partial_cls(buildable_or_callable, *args, **kwargs)` for a plain callable,
    or a copy of `buildable_or_callable` updated with `kwargs` when it is
    already a `partial.Partial`.

  Raises:
    ValueError: If positional arguments are supplied together with an existing
      `partial.Partial`.
  """
  if not isinstance(buildable_or_callable, partial.Partial):
    return partial_cls(buildable_or_callable, *args, **kwargs)

  if args:
    # Note: this can cause an issue even in when not chained, if the built
    # functools.partial object is called with arguments. We may later choose
    # to raise exceptions in those cases. For this case however, it's hard to
    # define any reasonable behavior, so we always error.
    raise ValueError(
        'For chained functools.partial calls inside auto_config, e.g. '
        'functools.partial(functools.partial(foo, ...), ...), only keyword '
        f'arguments can be supplied to the outer call. Got: {args!r}'
    )
  return copying.copy_with(buildable_or_callable, **kwargs)
def exempt(fn_or_cls: _GenericCallable) -> _GenericCallable:
  """Marks a callable as exempt from auto_config.

  Works both as a decorator on a function definition and inline, around a
  single call, inside an auto_config function.

  During auto_config transformation, exempted function calls will be evaluated
  normally rather than turned into a config object. For example::

    @exempt
    def my_square(x): return x * x

    @auto_config
    def build_model():
      return Model(a=np.square(3), b=exempt(np.square)(3), c=my_square(3))

    config = build_model.as_buildable()
    assert config.a == fdl.Config(np.square, 3)
    assert config.b == config.c == 9

  Args:
    fn_or_cls: Any callable.

  Returns:
    An `AutoConfig` that uses `fn_or_cls` for both the regular call and
    `as_buildable`, with `always_inline=True`.
  """
  exempted = AutoConfig(
      func=fn_or_cls, buildable_func=fn_or_cls, always_inline=True
  )
  return exempted  # pyrefly: ignore[bad-return]
def inlined_partial(
    auto_config_fn: AutoConfig, *args, **kwargs
) -> functools.partial:
  """Creates a `fdl.Partial` from an auto_config function.

  Called directly, this simply returns
  `functools.partial(auto_config_fn, *args, **kwargs)`.

  Called inside another auto_config function, the call is special cased by
  `auto_config`: the config returned by
  `auto_config_fn.as_buildable(*args, **kwargs)` is cast into a `fdl.Partial`,
  so the "internals" of the config structure remain accessible/modifiable
  during configuration.

  For example:

    @auto_config.auto_config
    def make_object(arg1):
      return SomeObject(arg1=arg1, arg2=2)

    @auto_config.auto_config
    def use_inlined_partial():
      return auto_config.inlined_partial(make_object, arg1=1)

    assert use_inlined_partial.as_buildable() == fdl.Partial(
        SomeObject, arg1=1, arg2=2)

  In comparison (retaining `make_object` from above):

    @auto_config.auto_config
    def use_functools_partial():
      return functools.partial(make_object, arg1=1)

    assert use_functools_partial.as_buildable() == fdl.Partial(
        make_object, arg1=1)

  Note that `auto_config_fn.as_buildable()` must return a single `fdl.Config`
  object; return values of containers or other `fdl.Partial`s are not
  supported.

  Args:
    auto_config_fn: The auto_config function to create a corresponding partial
      from.
    *args: Positional arguments to forward to `auto_config_fn`.
    **kwargs: Keyword arguments to forward to `auto_config_fn`.

  Returns:
    Inside an auto_config function, returns a `fdl.Partial` obtained from
    casting the return value of `auto_config_fn.as_buildable()`. Outside an
    auto_config function, returns a standard `functools.partial` instance.
  """
  plain_partial = functools.partial(auto_config_fn, *args, **kwargs)
  return plain_partial
@dataclasses.dataclass(frozen=True)
class ConfigTypes:
  """The buildable types that `auto_config` uses when generating configs."""

  # Type used for ordinary calls that are turned into configs.
  config_cls: Type[config.Config] = config.Config
  # Type used in place of `functools.partial` and related partial calls.
  partial_cls: Type[partial.Partial] = partial.Partial
  # Type used to wrap the factory arguments of `arg_factory.partial`.
  arg_factory_cls: Type[partial.ArgFactory] = partial.ArgFactory
[docs]
def auto_config(
fn=None,
*,
experimental_allow_dataclass_attribute_access=False,
experimental_allow_control_flow: bool = False,
experimental_always_inline: Optional[bool] = None,
experimental_exemption_policy: Optional[auto_config_policy.Policy] = None,
experimental_config_types: ConfigTypes = ConfigTypes(),
experimental_result_must_contain_buildable: bool = True,
) -> Any: # TODO(b/272377821): More precise return type.
"""Rewrites the given function to make it generate a ``Config``.
This function creates a new function from ``fn`` by rewriting its AST
(abstract syntax tree), replacing all ``Call`` nodes with a custom call
handler. When the rewritten function is run, the call handler intercepts calls
and applies the following rules:
- Calls to builtins, methods, callables without an inferrable signature,
callables wrapped by `auto_config.exempt`, or other functions that have
been ``auto_config``'ed take place as usual.
- Calls to ``functools.partial`` are replaced by calling ``fdl.Partial`` with
the same arguments;
- All other calls are replaced by calling ``fdl.Config`` with the arguments
that would have been passed to the called function or class.
This function may be used standalone or as a decorator. The returned function
is simply a wrapper around ``fn``, but with an additional ``as_buildable``
attribute containing the rewritten function. For example::
def build_model():
return Sequential([
Dense(num_units=128, activation=relu),
Dense(num_units=128, activation=relu),
Dense(num_units=1, activation=None),
])
config = auto_config(build_model).as_buildable()
The resulting ``config`` is equivalent to the following "manually" constructed
configuration graph::
fdl.Config(Sequential, layers=[
fdl.Config(Dense, num_units=128, activation=relu),
fdl.Config(Dense, num_units=128, activation=relu),
fdl.Config(Dense, num_units=1, activation=None),
])
This can then be built with ``fdl.build(config)``. Without modification, this
will result in the same model as just calling ``build_model()`` directly.
However, ``config`` permits changes to the model hyperparameters, for
example::
config.layers[0].num_units = 64
config.layers[0].activation = 'elu'
config.layers[1].num_units = 64
config.layers[1].activation = 'elu'
modified_model = fdl.build(config)
Currently, control flow is not supported by default in ``auto_config``.
Experimental support for control flow can be enabled using the
``experimental_allow_control_flow`` argument. If enabled, control flow
constructs may be used within the function to construct the resulting config
(for example, a ``for`` loop could be used to build a list of layers). Control
flow is never encoded directly as part of the resulting ``fdl.Config`` (for
example, there is no ``fdl.Config`` that will correspond to a conditional or
loop). While many simple constructs (``for _ in range(10)`` etc) work, there
will also likely be surprising behavior in some circumstances (for example,
using ``itertools`` functions in conjunction with a loop will not work, since
the calls to ``itertools`` functions will be turned into ``fdl.Config``
objects).
Using ``@auto_config`` is compatible with both ``@staticmethod`` and
``@classmethod``, however the ``@auto_config`` decorator must appear above the
``@classmethod`` or ``@staticmethod`` in the decorator list.
Args:
fn: The function to create a config-generating function from.
experimental_allow_dataclass_attribute_access: Whether to allow attribute
access on dataclasses within auto_config. Note that access to dataclass
attribute is transformed into access to fdl.Config attributes in the
as_buildable path.
experimental_allow_control_flow: Whether to allow control flow constructs in
``fn``. By default, control flow constructs will cause an
``UnsupportedLanguageConstructError`` to be thrown.
experimental_always_inline: If true, this function (when called in an
``auto_config`` context) will always be ``inline``'d in-place. See the
documentation on ``inline`` for an example. The default (if unspecified)
is currently ``True``.
experimental_exemption_policy: An optional policy to control which function
calls within the body of ``fn`` should be turned into ``fdl.Config``'s and
which ones should simply be executed normally during the ``as_buildable``
interpretation of ``fn``. This predicate should return ``True`` if the
given callable should be exempted from auto-configuration.
experimental_config_types: A ``ConfigTypes`` instance containing the types
to use when generating configs. By default, this just supplies the
standard Fiddle types (``fdl.Config``, ``fdl.Partial``, and
``fdl.ArgFactory``), but projects with custom subclasses can use this to
override the default. This is experimental and may be removed in the
future.
experimental_result_must_contain_buildable: If true, then raise an error if
`fn.as_buildable` returns a result that does not contain any `Buildable`
values -- e.g., if it returns an empty dict.
Returns:
A wrapped version of ``fn``, but with an additional ``as_buildable``
attribute containing the rewritten function.
"""
if experimental_always_inline is None:
experimental_always_inline = True
if experimental_exemption_policy is None:
experimental_exemption_policy = auto_config_policy.latest
def auto_config_call_handler(fn_or_cls, /, *args, **kwargs):
"""Handles calls in auto_config'ed functions.
This intercepts calls in an auto-configed function, and determines whether
the called `fn_or_cls` should be wrapped in a `Config` or `Partial`. If
`fn_or_cls` is `functools.partial`, the call will instead be converted into
a call to Fiddle's `Partial`. If it is "auto-config eligible" (see
`experimental_exemption_policy`), then a `Config` will be created for
`fn_or_cls` with the provided arguments. Otherwise, `fn_or_cls` is called
directly.
Args:
fn_or_cls: The function or class being called.
*args: The positional arguments with which `fn_or_cls` is being called.
**kwargs: The keyword arguments with which `fn_or_cls` is being called.
Returns:
Depending on `fn_or_cls`, either `Partial`, a `Config`, or the result of
calling `fn_or_cls` with the provided arguments.
"""
if isinstance(fn_or_cls, AutoConfig) and fn_or_cls.always_inline:
return fn_or_cls.as_buildable(*args, **kwargs)
partial_cls = experimental_config_types.partial_cls
if fn_or_cls is functools.partial:
return _make_partial(partial_cls, args[0], *args[1:], **kwargs)
elif fn_or_cls is arg_factory.partial:
arg_factory_cls = experimental_config_types.arg_factory_cls
return _make_partial(
partial_cls,
args[0],
*[_maybe_as_arg_factory(arg_factory_cls, arg) for arg in args[1:]],
**{
name: _maybe_as_arg_factory(arg_factory_cls, arg)
for name, arg in kwargs.items()
},
)
elif fn_or_cls is inlined_partial:
auto_config_fn, partial_args = args[0], args[1:]
if not isinstance(auto_config_fn, AutoConfig):
raise ValueError(
'inlined_partial should only be applied to auto_config functions,'
f' received: {auto_config_fn}'
)
cfg = auto_config_fn.as_buildable(*partial_args, **kwargs)
if not isinstance(cfg, config.Config):
raise ValueError(
'inlined_partial should only be applied to auto_config functions'
f' that create a single top-level Config, received: {cfg}'
)
return cast_lib.cast(partial_cls, cfg)
if fn_or_cls is exempt:
return fn_or_cls(*args, **kwargs)
if experimental_exemption_policy(fn_or_cls):
return fn_or_cls(*args, **kwargs)
return experimental_config_types.config_cls(fn_or_cls, *args, **kwargs)
def auto_config_attr_load_handler(value, attr, allow_dataclass=True):
"""Handles attribute access in auto_config'ed functions."""
if isinstance(value, config.Buildable):
fn_or_cls = value.__fn_or_cls__
if allow_dataclass and dataclasses.is_dataclass(fn_or_cls):
return getattr(value, attr)
raise ValueError(
f'Cannot load attribute {attr!r} on object of type {type(value)}'
' within auto_config, as this could lead to inconsistent behavior'
' between the Python and as_buildable code paths.'
)
return getattr(value, attr)
def auto_config_attr_save_handler(obj, attr, value, allow_dataclass=True):
"""Handles saving attributes in auto_config'ed functions."""
if isinstance(obj, config.Buildable):
fn_or_cls = obj.__fn_or_cls__
if allow_dataclass and dataclasses.is_dataclass(fn_or_cls):
setattr(obj, attr, value)
return
raise ValueError(
f'Cannot save attribute {attr!r} on object of type {type(obj)}'
' within auto_config, as this could lead to inconsistent behavior'
' between the Python and as_buildable code paths.'
)
def make_auto_config(fn):
  """Rewrites `fn` and packages it with its rewritten form in an `AutoConfig`.

  The source of `fn` is parsed, run through `_AutoConfigNodeTransformer`,
  recompiled, and turned into a new function object that shares `fn`'s
  globals, defaults and closure cells, plus extra cells for the handler
  functions that the rewritten code refers to.

  NOTE(review): `experimental_allow_control_flow`,
  `experimental_allow_dataclass_attribute_access`,
  `experimental_result_must_contain_buildable`, `experimental_always_inline`
  and `auto_config_call_handler` are free variables in this function;
  presumably they come from the enclosing decorator's scope -- confirm there.

  Args:
    fn: A plain function, or a `classmethod` / `staticmethod` wrapping one.

  Returns:
    An `AutoConfig` built from `fn` and its rewritten counterpart.

  Raises:
    ValueError: If `fn` is not a function, `classmethod` or `staticmethod`.
  """
  if not isinstance(fn, (types.FunctionType, classmethod, staticmethod)):
    raise ValueError(
        '`auto_config` is only compatible with functions, '
        f'`@classmethod`s, and `@staticmethod`s. Got {fn!r} '
        f'with type {type(fn)!r}.'
    )

  # Unwrap method descriptors so the underlying function can be rewritten; the
  # descriptor type is remembered and re-applied to both variants at the end.
  if isinstance(fn, (classmethod, staticmethod)):
    method_type = type(fn)
    fn = fn.__func__
  else:
    method_type = None

  source = _getsource(fn)

  # Create the NodeTransformer that will transform the AST. The
  # `_AutoConfigNodeTransformer` requires some additional information about
  # the source to provide more informative error messages.
  filename = inspect.getsourcefile(fn)
  line_number = fn.__code__.co_firstlineno
  node_transformer = _AutoConfigNodeTransformer(
      source=source,
      filename=filename,  # pyrefly: ignore[bad-argument-type]
      line_number=line_number,
      allow_control_flow=experimental_allow_control_flow,
  )

  # Parse the AST, and modify it by intercepting all `Call`s with the
  # `auto_config_call_handler`. Finally, ensure line numbers and code
  # locations match up with the original function, to make errors
  # interpretable.
  node = ast.parse(source)
  node = node_transformer.visit(node)
  node = ast.fix_missing_locations(node)
  node = ast.increment_lineno(node, line_number - 1)
  assert isinstance(node, ast.Module)

  # In order to allow us to use the original function closure below when
  # constructing a new function object, we have to nest our modified AST
  # within an outer `FunctionDef` that defines variables corresponding to the
  # free variables in `fn`.
  node = _wrap_ast_for_fn_with_closure_vars(node, fn)

  # Compile the modified AST, and then find the function code object within
  # the returned module-level code object.
  code = compile(node, inspect.getsourcefile(fn), 'exec')  # pyrefly: ignore[bad-argument-type]
  code = _unwrap_code_for_fn(code, fn)

  # Insert auto_config_attr_load_handler, auto_config_attr_save_handler,
  # auto_config_call_handler into `fn.__closure__` at the index where
  # _ATTR_LOAD_HANDLER_ID, _ATTR_SAVE_HANDLER_ID, _CALL_HANDLER_ID
  # occur in the freevars. All of them were added to freevars by
  # _wrap_ast_for_fn_with_closure_vars.
  closure = list(fn.__closure__ or ())
  indexed_handlers = []
  for handler_id, handler in (
      (_ATTR_LOAD_HANDLER_ID, auto_config_attr_load_handler),
      (_ATTR_SAVE_HANDLER_ID, auto_config_attr_save_handler),
  ):
    # A handler only needs a cell if the rewritten code actually refers to it.
    if handler_id in code.co_freevars:
      handler_idx = code.co_freevars.index(handler_id)
      # Bind the dataclass-access option into the handler before wrapping it
      # in a closure cell.
      handler = _make_closure_cell(
          functools.partial(
              handler,
              allow_dataclass=experimental_allow_dataclass_attribute_access,
          )
      )
      indexed_handlers.append((handler_idx, handler))
  if _CALL_HANDLER_ID in code.co_freevars:
    handler_idx = code.co_freevars.index(_CALL_HANDLER_ID)
    handler = _make_closure_cell(auto_config_call_handler)
    indexed_handlers.append((handler_idx, handler))
  # Insert handlers in ascending index order, so that each insertion lands at
  # its final position and the closure contents are not mismatched with
  # `code.co_freevars`.
  for handler_idx, handler in sorted(indexed_handlers):
    closure.insert(handler_idx, handler)
  closure = tuple(closure)

  # Then, create a function from the compiled function code object, providing
  # the globals and the original function's closure.
  auto_config_fn = types.FunctionType(code, fn.__globals__, closure=closure)
  auto_config_fn.__defaults__ = fn.__defaults__
  auto_config_fn.__kwdefaults__ = fn.__kwdefaults__

  # Finally we wrap the rewritten function to perform additional error
  # checking and enforce that the output contains a `fdl.Buildable`.
  if experimental_result_must_contain_buildable:

    @functools.wraps(auto_config_fn)
    def as_buildable(*args, **kwargs):
      output = auto_config_fn(*args, **kwargs)  # pylint: disable=not-callable
      if not _contains_buildable(output):
        raise TypeError(
            f'The `auto_config` rewritten version of `{fn.__qualname__}` '
            f'returned a `{type(output).__name__}`, which is not (or did not '
            'contain) a `fdl.Buildable`. Please ensure this function returns '
            'the result of an `auto_config`-eligible call expression, or a '
            'supported container (list, tuple, dict) containing one.'
        )
      return output

  else:
    as_buildable = auto_config_fn

  # Re-apply the `classmethod` / `staticmethod` wrapper removed above to both
  # the original and the rewritten function.
  if method_type:
    fn = method_type(fn)
    as_buildable = method_type(as_buildable)
  return AutoConfig(
      fn, as_buildable, always_inline=experimental_always_inline  # pyrefly: ignore[bad-argument-type]
  )
# Decorator with empty parenthesis.
if fn is None:
return make_auto_config
else:
return make_auto_config(fn)
[docs]
def auto_unconfig(
    fn=None, *, experimental_always_inline: Optional[bool] = None
) -> Any:  # TODO(b/272377821): More precise return type.
  """Converts functions that create buildables directly into auto_config form.

  While most of the time, the benefits of an auto_config representation of
  object configuration and construction are valuable (e.g. static type
  checking and tooling / refactoring support), sometimes it is more convenient
  to manipulate buildable objects directly. ``auto_unconfig`` converts a
  function that directly manipulates ``fdl.Buildable``'s (e.g.
  ``fdl.Config``'s) into one that looks identical to an ``auto_config``'d
  function, and is fully interoperable with the rest of the ``auto_config``
  ecosystem.

  The Python code path of the result calls ``fn`` and then runs
  ``building.build`` on what it returns; the ``as_buildable`` code path is
  ``fn`` itself.

  Example::

    @auto_unconfig
    def make_experiment_trainer(name: str) -> fdl.Config[MyTrainer]:
      model = make_model.as_buildable(name)
      select(model, DropOut).set(rate=0.42)  # Full Fiddle API available!

      dataset = make_dataset.as_buildable()

      # Build fdl.Config's imperatively.
      trainer_config = fdl.Config(MyTrainer)
      trainer_config.model = model
      trainer_config.train_dataset = dataset
      trainer_config.skip_eval = True

      return trainer_config  # Return a `fdl.Buildable`

    # Sample usage within an auto_config'd function.
    @auto_config
    def make_driver():
      return TrainerDriver(
          trainer=make_experiment_trainer('my_experiment'),
          checkpointer=CustomCheckpointer())

    # Sample usage outside of auto_config contexts.
    def main():
      # Use instantiated objects:
      trainer = make_experiment_trainer('my_experiment')
      for example in trainer.train_dataset:
        print_prediction(trainer.model.predict(example))

      # Or manipulate the configuration before calling `fdl.build`:
      trainer_config = make_experiment_trainer.as_buildable('my_experiment')
      trainer_config.skip_eval = False  # Tweak configuration.
      trainer2 = fdl.build(trainer_config)
      run_trainer(trainer2)

  Args:
    fn: The function to convert.
    experimental_always_inline: Whether the output of ``fn`` should always be
      inlined into the caller's config. The default (if unspecified) is
      ``True``.

  Returns:
    An ``AutoConfig`` that corresponds to ``fn``, or, when ``fn`` is omitted
    (decorator used with parentheses), a decorator producing one.
  """
  always_inline = (
      True if experimental_always_inline is None else experimental_always_inline
  )

  def make_unconfig(fn) -> AutoConfig:
    @functools.wraps(fn)
    def python_implementation(*args, **kwargs):
      # Clear the in-build flag for the duration of the call and put the
      # previous value back afterwards, even if `fn` or the build raises.
      saved_in_build = building._state.in_build  # pytype: disable=module-attr  # pylint: disable=protected-access
      building._state.in_build = False  # pytype: disable=module-attr  # pylint: disable=protected-access
      try:
        return building.build(fn(*args, **kwargs))
      finally:
        building._state.in_build = saved_in_build  # pytype: disable=module-attr  # pylint: disable=protected-access

    return AutoConfig(
        func=python_implementation,
        buildable_func=fn,
        always_inline=always_inline,
    )

  # Supports use both as `@auto_unconfig` and as `@auto_unconfig(...)`.
  return make_unconfig if fn is None else make_unconfig(fn)
[docs]
def is_auto_config(function_object: Any) -> bool:
  """Returns True if `function_object` is an `AutoConfig` instance."""
  return isinstance(function_object, AutoConfig)
[docs]
def inline(buildable: config.Config):
  """Converts an ``auto_config``-based ``buildable`` into a DAG of Buildables.

  ``inline`` updates ``buildable`` in place to preserve aliasing within a
  larger Fiddle configuration. If you would like to leave ``buildable``
  unmodified, make a shallow copy (``copy.copy``) before calling ``inline``.

  Example::

    # shared/input_pipelines.py
    @auto_config(experimental_always_inline=False)
    def make_input_pipeline(name: str, batch_size: int) -> InputPipeline:
      file_path = '/base_path/'+name
      augmentation = 'my_augmentation_routine'
      # ...
      return InputPipeline(file_path, augmentation, ...)

    # config/main.py
    @auto_config
    def make_experiment():
      data = make_input_pipeline('normal_dataset', batch_size)
      model = ...
      return Experiment(data, model)

    # experiment_configuration.py
    def make_experiment():
      config = make_experiment.as_buildable()
      config.data.name = 'advanced_dataset'
      # config.data.augmentation = 'custom_augmentation'  # Not configurable!!!
      # return fdl.build(config)  # Works like normal.
      auto_config.inline(config.data)
      print(config.data.file_path)  # Prints: '/base_path/advanced_dataset'
      config.data.augmentation = 'custom_augmentation'  # Now exposed.
      experiment = fdl.build(config)  # Works like normal.
      return experiment

  Args:
    buildable: The buildable of an ``auto_config``'d function to replace with
      the root of a Fiddle DAG that corresponds to it.

  Raises:
    ValueError: If ``buildable`` is not a ``Config``, if ``buildable`` doesn't
      correspond to an ``auto_config``'d function, or if that function's
      ``as_buildable`` form does not return a ``Buildable``.
  """
  if not isinstance(buildable, config.Config):
    raise ValueError(
        'Cannot `inline` non-Config buildables; '
        f'{type(buildable)} is not compatible.'
    )
  if not is_auto_config(buildable.__fn_or_cls__):
    raise ValueError(
        'Cannot `inline` a non-auto_config function; '
        f'`{buildable.__fn_or_cls__}` is not compatible.'
    )

  # Re-run the configured function through its `as_buildable` interpretation,
  # using the arguments currently stored on `buildable`.
  wrapped_fn = cast(AutoConfig, buildable.__fn_or_cls__)
  expanded = wrapped_fn.as_buildable(**buildable.__arguments__)  # pyrefly: ignore[bad-unpacking]
  if not isinstance(expanded, config.Buildable):
    raise ValueError(
        'You cannot currently inline functions that do not return '
        '`fdl.Buildable`s.'
    )

  # Transplant the freshly created internals into the caller's object so that
  # existing references to `buildable` observe the inlined result.
  mutate_buildable.move_buildable_internals(
      source=expanded, destination=buildable
  )
def _getsource(fn: Any) -> str:
  """Returns the source code for callable `fn`.

  Lambdas are delegated to `_getsource_for_lambda`; for everything else the
  text from `inspect.getsource` is returned with common leading whitespace
  removed.
  """
  if not _is_lambda(fn):
    # Nested functions and methods come back indented, which would break AST
    # parsing, so strip the common indentation.
    return textwrap.dedent(inspect.getsource(fn))
  return _getsource_for_lambda(fn)
def _is_lambda(fn: Any) -> bool:
"""Returns True if `fn` is a lambda function."""
if not inspect.isfunction(fn):
return False
if not (hasattr(fn, '__name__') and hasattr(fn, '__code__')):
return False
return (fn.__name__ == '<lambda>') or (fn.__code__.co_name == '<lambda>')
class _LambdaFinder(cst.CSTVisitor):
  """CST visitor that collects lambda nodes starting on a given source line.

  The target line is the `co_firstlineno` of the lambda function passed to the
  constructor; every `Lambda` node whose start line equals it is appended to
  `candidates`.
  """

  METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,)

  def __init__(self, lambda_fn):
    super().__init__()
    self.candidates = []
    self.lambda_fn = lambda_fn
    self.lineno = lambda_fn.__code__.co_firstlineno

  def visit_Lambda(self, node) -> None:
    position = self.get_metadata(cst.metadata.PositionProvider, node)
    starts_on_target_line = position.start.line == self.lineno  # pyrefly: ignore[missing-attribute]
    if starts_on_target_line:
      self.candidates.append(node)
def _getsource_for_lambda(fn: Callable[..., Any]) -> str:
  """Returns source code for the given lambda function.

  The whole defining module is parsed with libcst and searched for lambda
  expressions that start on the line where `fn` was defined.

  Args:
    fn: The lambda function whose source should be located.

  Returns:
    The source text of the single matching lambda expression.

  Raises:
    ValueError: If no lambda, or more than one lambda, starts on that line.
  """
  # Load the full text of the module that defines `fn`.
  module = inspect.getmodule(fn)
  filename = inspect.getsourcefile(fn)
  source = ''.join(linecache.getlines(filename, module.__dict__))  # pyrefly: ignore[bad-argument-type]

  # Walk the module's CST, gathering lambdas on the definition line.
  lambda_finder = _LambdaFinder(fn)
  cst.metadata.MetadataWrapper(cst.parse_module(source)).visit(lambda_finder)
  matches = lambda_finder.candidates

  if not matches:
    raise ValueError(
        'Fiddle auto_config was unable to find the source code for '
        f'{fn}: could not find lambda on line {lambda_finder.lineno}.'
    )
  if len(matches) != 1:
    # TODO(b/258671226): If desired, we could narrow down which lambda is
    # used based on the signature (or even fancier things like the checking
    # fn.__code__.co_names).
    raise ValueError(
        'Fiddle auto_config was unable to find the source code for '
        f'{fn}: multiple lambdas found on line {lambda_finder.lineno}; '
        'try moving each lambda to its own line.'
    )
  return cst.Module(body=[matches[0]]).code
def with_buildable_func(
    buildable_func: Callable[..., Any]
) -> Callable[..., Any]:
  """A decorator that adds an auto_config-only code path.

  Args:
    buildable_func: The callable to use as the `buildable_func` of the
      resulting `AutoConfig`.

  Returns:
    A decorator that wraps the decorated function in an `AutoConfig` with
    `always_inline=True`, pairing it with `buildable_func`.
  """

  def decorator(func):
    auto_config_kwargs = dict(
        func=func, buildable_func=buildable_func, always_inline=True
    )
    return AutoConfig(**auto_config_kwargs)

  return decorator