"""HTTP routes for managing and executing transformations.

Exposes list/create/read/update/delete endpoints for ``Transformation``
records, plus an endpoint that runs a stored transformation against input
text through the transformation graph.
"""

from typing import List

from fastapi import APIRouter, HTTPException
from loguru import logger

from api.models import (
    TransformationCreate,
    TransformationExecuteRequest,
    TransformationExecuteResponse,
    TransformationResponse,
    TransformationUpdate,
)
from open_notebook.domain.models import Model
from open_notebook.domain.transformation import Transformation
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
from open_notebook.graphs.transformation import graph as transformation_graph

router = APIRouter()

# Attributes a PUT request is allowed to overwrite; a value of ``None`` in the
# update payload means "leave this attribute unchanged".
_UPDATABLE_FIELDS = ("name", "title", "description", "prompt", "apply_default")


def _to_response(transformation: Transformation) -> TransformationResponse:
    """Build the API response model for a transformation record.

    ``created`` and ``updated`` are passed through ``str()`` before being
    placed in the response.
    """
    return TransformationResponse(
        id=transformation.id,
        name=transformation.name,
        title=transformation.title,
        description=transformation.description,
        prompt=transformation.prompt,
        apply_default=transformation.apply_default,
        created=str(transformation.created),
        updated=str(transformation.updated),
    )


async def _get_transformation_or_404(transformation_id: str) -> Transformation:
    """Load a transformation by ID, raising a 404 if the lookup is falsy.

    Raises:
        HTTPException: 404 when ``Transformation.get`` returns a falsy value.
            Any exception raised by ``Transformation.get`` itself propagates
            unchanged so the calling handler can map it to a 500.
    """
    transformation = await Transformation.get(transformation_id)
    if not transformation:
        raise HTTPException(status_code=404, detail="Transformation not found")
    return transformation


@router.get("/transformations", response_model=List[TransformationResponse])
async def get_transformations() -> List[TransformationResponse]:
    """Get all transformations.

    Returns:
        Every transformation, requested from the store with
        ``order_by="name asc"``.

    Raises:
        HTTPException: 500 on any failure while fetching or serialising.
    """
    try:
        transformations = await Transformation.get_all(order_by="name asc")
        return [_to_response(transformation) for transformation in transformations]
    except Exception as e:
        logger.error(f"Error fetching transformations: {e!s}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching transformations: {e!s}"
        ) from e


@router.post("/transformations", response_model=TransformationResponse)
async def create_transformation(
    transformation_data: TransformationCreate,
) -> TransformationResponse:
    """Create a new transformation.

    Args:
        transformation_data: Field values for the new transformation.

    Returns:
        The saved transformation.

    Raises:
        HTTPException: 400 when saving raises ``InvalidInputError``;
            500 on any other failure.
    """
    try:
        new_transformation = Transformation(
            name=transformation_data.name,
            title=transformation_data.title,
            description=transformation_data.description,
            prompt=transformation_data.prompt,
            apply_default=transformation_data.apply_default,
        )
        await new_transformation.save()
        return _to_response(new_transformation)
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"Error creating transformation: {e!s}")
        raise HTTPException(
            status_code=500, detail=f"Error creating transformation: {e!s}"
        ) from e


@router.get(
    "/transformations/{transformation_id}", response_model=TransformationResponse
)
async def get_transformation(transformation_id: str) -> TransformationResponse:
    """Get a specific transformation by ID.

    Args:
        transformation_id: Identifier of the transformation to fetch.

    Raises:
        HTTPException: 404 when no transformation is found; 500 on any
            other failure.
    """
    try:
        transformation = await _get_transformation_or_404(transformation_id)
        return _to_response(transformation)
    except HTTPException:
        # Let the 404 through untouched instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error fetching transformation {transformation_id}: {e!s}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching transformation: {e!s}"
        ) from e


@router.put(
    "/transformations/{transformation_id}", response_model=TransformationResponse
)
async def update_transformation(
    transformation_id: str, transformation_update: TransformationUpdate
) -> TransformationResponse:
    """Update a transformation.

    Only fields whose value in ``transformation_update`` is not ``None`` are
    written; the rest keep their stored values.

    Args:
        transformation_id: Identifier of the transformation to update.
        transformation_update: Partial set of new field values.

    Raises:
        HTTPException: 404 when no transformation is found; 400 when saving
            raises ``InvalidInputError``; 500 on any other failure.
    """
    try:
        transformation = await _get_transformation_or_404(transformation_id)

        # Update only provided fields
        for field_name in _UPDATABLE_FIELDS:
            new_value = getattr(transformation_update, field_name)
            if new_value is not None:
                setattr(transformation, field_name, new_value)

        await transformation.save()
        return _to_response(transformation)
    except HTTPException:
        raise
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"Error updating transformation {transformation_id}: {e!s}")
        raise HTTPException(
            status_code=500, detail=f"Error updating transformation: {e!s}"
        ) from e


@router.delete("/transformations/{transformation_id}")
async def delete_transformation(transformation_id: str):
    """Delete a transformation.

    Args:
        transformation_id: Identifier of the transformation to delete.

    Returns:
        A dict with a single ``"message"`` key confirming the deletion.

    Raises:
        HTTPException: 404 when no transformation is found; 500 on any
            other failure.
    """
    try:
        transformation = await _get_transformation_or_404(transformation_id)
        await transformation.delete()
        return {"message": "Transformation deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting transformation {transformation_id}: {e!s}")
        raise HTTPException(
            status_code=500, detail=f"Error deleting transformation: {e!s}"
        ) from e


@router.post("/transformations/execute", response_model=TransformationExecuteResponse)
async def execute_transformation(
    execute_request: TransformationExecuteRequest,
) -> TransformationExecuteResponse:
    """Execute a transformation on input text.

    Both the transformation and the model are looked up before the graph is
    invoked, so an unknown ID is reported as a 404 rather than surfacing as a
    graph failure.

    Args:
        execute_request: Carries ``transformation_id``, ``model_id`` and the
            ``input_text`` to transform.

    Returns:
        The graph's ``"output"`` value together with the IDs that were used.

    Raises:
        HTTPException: 404 when the transformation or the model is not found;
            500 on any other failure.
    """
    try:
        # Validate transformation exists
        transformation = await _get_transformation_or_404(
            execute_request.transformation_id
        )

        # Validate model exists
        model = await Model.get(execute_request.model_id)
        if not model:
            raise HTTPException(status_code=404, detail="Model not found")

        # Execute the transformation
        result = await transformation_graph.ainvoke(
            {
                "input_text": execute_request.input_text,
                "transformation": transformation,
            },
            config={"configurable": {"model_id": execute_request.model_id}},
        )

        return TransformationExecuteResponse(
            output=result["output"],
            transformation_id=execute_request.transformation_id,
            model_id=execute_request.model_id,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error executing transformation: {e!s}")
        raise HTTPException(
            status_code=500, detail=f"Error executing transformation: {e!s}"
        ) from e