344 lines
13 KiB
Python
344 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import inspect
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
|
|
|
|
from griffe import Docstring, DocstringSectionKind
|
|
from pydantic import BaseModel, Field, create_model
|
|
|
|
from .exceptions import UserError
|
|
from .run_context import RunContextWrapper
|
|
from .strict_schema import ensure_strict_json_schema
|
|
|
|
|
|
@dataclass
class FuncSchema:
    """
    Describes a python function as an LLM tool: its name, description, parameter model, and the
    JSON schema derived from that model.
    """

    name: str
    """Name of the function."""
    description: str | None
    """Description of the function, if there is one."""
    params_pydantic_model: type[BaseModel]
    """Pydantic model representing the function's parameters."""
    params_json_schema: dict[str, Any]
    """JSON schema for the function's parameters, derived from the Pydantic model."""
    signature: inspect.Signature
    """Signature of the function."""
    takes_context: bool = False
    """Whether the function takes a RunContextWrapper argument (must be the first argument)."""
    strict_json_schema: bool = True
    """Whether the JSON schema is in strict mode. Setting this to True is **strongly**
    recommended, as it increases the likelihood of correct JSON input."""

    def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
        """
        Splits validated data from the Pydantic model into (args, kwargs) that can be used to
        call the original function.

        Values are read from `data` by parameter name; an attribute that is missing is treated
        as None.
        """
        args: list[Any] = []
        kwargs: dict[str, Any] = {}
        past_var_positional = False

        parameters = list(self.signature.parameters.items())
        if self.takes_context:
            # The first parameter is the RunContextWrapper; it is not read from `data`.
            parameters = parameters[1:]

        for param_name, parameter in parameters:
            arg_value = getattr(data, param_name, None)
            kind = parameter.kind

            if kind == parameter.VAR_POSITIONAL:
                # *args: splice the collected values into the positional list.
                args.extend(arg_value or [])
                past_var_positional = True
                continue

            if kind == parameter.VAR_KEYWORD:
                # **kwargs: merge the collected mapping into the keyword arguments.
                kwargs.update(arg_value or {})
                continue

            positional_capable = kind in (
                parameter.POSITIONAL_ONLY,
                parameter.POSITIONAL_OR_KEYWORD,
            )
            if positional_capable and not past_var_positional:
                args.append(arg_value)
            else:
                # Keyword-only parameters, and anything that follows *args, go by name.
                kwargs[param_name] = arg_value

        return args, kwargs
|
|
|
|
|
|
@dataclass
class FuncDocumentation:
    """Metadata about a python function, as extracted from its docstring."""

    name: str
    """Name of the function, taken from `__name__`."""
    description: str | None
    """Description of the function, derived from the docstring."""
    param_descriptions: dict[str, str] | None
    """Descriptions of the function's parameters, derived from the docstring."""
|
|
|
|
|
|
# Names of the docstring conventions this module can parse. A value of this type is passed to
# griffe's `Docstring` as its `parser` argument.
DocstringStyle = Literal["google", "numpy", "sphinx"]
|
|
|
|
|
|
# As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This
|
|
# code approximates it.
|
|
def _detect_docstring_style(doc: str) -> DocstringStyle:
|
|
scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0}
|
|
|
|
# Sphinx style detection: look for :param, :type, :return:, and :rtype:
|
|
sphinx_patterns = [r"^:param\s", r"^:type\s", r"^:return:", r"^:rtype:"]
|
|
for pattern in sphinx_patterns:
|
|
if re.search(pattern, doc, re.MULTILINE):
|
|
scores["sphinx"] += 1
|
|
|
|
# Numpy style detection: look for headers like 'Parameters', 'Returns', or 'Yields' followed by
|
|
# a dashed underline
|
|
numpy_patterns = [
|
|
r"^Parameters\s*\n\s*-{3,}",
|
|
r"^Returns\s*\n\s*-{3,}",
|
|
r"^Yields\s*\n\s*-{3,}",
|
|
]
|
|
for pattern in numpy_patterns:
|
|
if re.search(pattern, doc, re.MULTILINE):
|
|
scores["numpy"] += 1
|
|
|
|
# Google style detection: look for section headers with a trailing colon
|
|
google_patterns = [r"^(Args|Arguments):", r"^(Returns):", r"^(Raises):"]
|
|
for pattern in google_patterns:
|
|
if re.search(pattern, doc, re.MULTILINE):
|
|
scores["google"] += 1
|
|
|
|
max_score = max(scores.values())
|
|
if max_score == 0:
|
|
return "google"
|
|
|
|
# Priority order: sphinx > numpy > google in case of tie
|
|
styles: list[DocstringStyle] = ["sphinx", "numpy", "google"]
|
|
|
|
for style in styles:
|
|
if scores[style] == max_score:
|
|
return style
|
|
|
|
return "google"
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def _suppress_griffe_logging():
|
|
# Supresses warnings about missing annotations for params
|
|
logger = logging.getLogger("griffe")
|
|
previous_level = logger.getEffectiveLevel()
|
|
logger.setLevel(logging.ERROR)
|
|
try:
|
|
yield
|
|
finally:
|
|
logger.setLevel(previous_level)
|
|
|
|
|
|
def generate_func_documentation(
    func: Callable[..., Any], style: DocstringStyle | None = None
) -> FuncDocumentation:
    """
    Pulls the name, description and parameter descriptions out of a function's docstring, in
    preparation for sending it to an LLM as a tool.

    Args:
        func: The function to extract documentation from.
        style: The docstring style to parse with. If not provided, the style is auto-detected
            from the docstring text.

    Returns:
        A FuncDocumentation object with the function's name, description, and parameter
        descriptions. Description and parameter descriptions are None when the function has no
        docstring; parameter descriptions are also None when the docstring documents none.
    """
    func_name = func.__name__
    raw_doc = inspect.getdoc(func)
    if not raw_doc:
        return FuncDocumentation(name=func_name, description=None, param_descriptions=None)

    with _suppress_griffe_logging():
        parser_style = style or _detect_docstring_style(raw_doc)
        sections = Docstring(raw_doc, lineno=1, parser=parser_style).parse()

    # The description is the value of the first text section, if there is one.
    summary: str | None = None
    for section in sections:
        if section.kind == DocstringSectionKind.text:
            summary = section.value
            break

    # Gather parameter descriptions from every parameters section.
    documented_params: dict[str, str] = {}
    for section in sections:
        if section.kind != DocstringSectionKind.parameters:
            continue
        for entry in section.value:
            documented_params[entry.name] = entry.description

    return FuncDocumentation(
        name=func_name,
        description=summary,
        param_descriptions=documented_params or None,
    )
|
|
|
|
|
|
def function_schema(
    func: Callable[..., Any],
    docstring_style: DocstringStyle | None = None,
    name_override: str | None = None,
    description_override: str | None = None,
    use_docstring_info: bool = True,
    strict_json_schema: bool = True,
) -> FuncSchema:
    """
    Given a python function, extracts a `FuncSchema` from it, capturing the name, description,
    parameter descriptions, and other metadata.

    Args:
        func: The function to extract the schema from.
        docstring_style: The style of the docstring to use for parsing. If not provided, we will
            attempt to auto-detect the style.
        name_override: If provided, use this name instead of the function's `__name__`. This
            applies whether or not `use_docstring_info` is set.
        description_override: If provided, use this description instead of the one derived from
            the docstring. This applies whether or not `use_docstring_info` is set.
        use_docstring_info: If True, uses the docstring to generate the description and parameter
            descriptions.
        strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
            the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
            recommend setting this to True, as it increases the likelihood of the LLM providing
            correct JSON input.

    Returns:
        A `FuncSchema` object containing the function's name, description, parameter descriptions,
        and other metadata.

    Raises:
        UserError: If a parameter other than the first is annotated with `RunContextWrapper`.
    """

    # 1. Grab docstring info
    param_descs: dict[str, str]
    if use_docstring_info:
        doc_info = generate_func_documentation(func, docstring_style)
        param_descs = doc_info.param_descriptions or {}
    else:
        doc_info = None
        param_descs = {}

    # A conditional expression binds more loosely than `or`, so the fallback must be
    # parenthesized; otherwise the overrides are ignored whenever `doc_info` is None.
    func_name = name_override or (doc_info.name if doc_info else func.__name__)
    description = description_override or (doc_info.description if doc_info else None)

    # 2. Inspect function signature and get type hints
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    empty = inspect.Parameter.empty
    takes_context = False
    filtered_params: list[tuple[str, inspect.Parameter]] = []

    for index, (name, param) in enumerate(sig.parameters.items()):
        # Prefer the evaluated type hint if available
        ann = type_hints.get(name, param.annotation)
        is_context = ann is not empty and (get_origin(ann) or ann) is RunContextWrapper

        if not is_context:
            filtered_params.append((name, param))
        elif index == 0:
            # The context parameter is not part of the tool's arguments, so leave it out.
            takes_context = True
        else:
            raise UserError(
                f"RunContextWrapper param found at non-first position in function"
                f" {func.__name__}"
            )

    # We will collect field definitions for create_model as a dict:
    #   field_name -> (type_annotation, default_value_or_Field(...))
    fields: dict[str, Any] = {}

    for name, param in filtered_params:
        ann = type_hints.get(name, param.annotation)
        default = param.default

        # If there's no type hint, assume `Any`
        if ann is empty:
            ann = Any

        # If a docstring param description exists, use it
        field_description = param_descs.get(name, None)

        # Handle different parameter kinds
        if param.kind == param.VAR_POSITIONAL:
            # *args is exposed as a list field
            if get_origin(ann) is tuple:
                # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int]
                args_of_tuple = get_args(ann)
                if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
                    ann = list[args_of_tuple[0]]  # type: ignore
                else:
                    ann = list[Any]
            else:
                # If user wrote *args: int, treat as List[int]
                ann = list[ann]  # type: ignore

            # Default factory to empty list
            fields[name] = (
                ann,
                Field(default_factory=list, description=field_description),  # type: ignore
            )

        elif param.kind == param.VAR_KEYWORD:
            # **kwargs is exposed as a dict field
            if get_origin(ann) is dict:
                # e.g. def foo(**kwargs: dict[str, int])
                dict_args = get_args(ann)
                if len(dict_args) == 2:
                    ann = dict[dict_args[0], dict_args[1]]  # type: ignore
                else:
                    ann = dict[str, Any]
            else:
                # e.g. def foo(**kwargs: int) -> Dict[str, int]
                ann = dict[str, ann]  # type: ignore

            fields[name] = (
                ann,
                Field(default_factory=dict, description=field_description),  # type: ignore
            )

        elif default is empty:
            # Normal parameter without a default: required field.
            # Identity is used (not ==) so that a default value's own __eq__ is never invoked.
            fields[name] = (
                ann,
                Field(..., description=field_description),
            )

        else:
            # Normal parameter with a default value
            fields[name] = (
                ann,
                Field(default=default, description=field_description),
            )

    # 3. Dynamically build a Pydantic model
    dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)

    # 4. Build JSON schema from that model
    json_schema = dynamic_model.model_json_schema()
    if strict_json_schema:
        json_schema = ensure_strict_json_schema(json_schema)

    # 5. Return as a FuncSchema dataclass
    return FuncSchema(
        name=func_name,
        description=description,
        params_pydantic_model=dynamic_model,
        params_json_schema=json_schema,
        signature=sig,
        takes_context=takes_context,
        strict_json_schema=strict_json_schema,
    )
|