diff --git a/rag_tutorials/vision_rag_agent/README.md b/rag_tutorials/vision_rag_agent/README.md new file mode 100644 index 0000000..2b2f015 --- /dev/null +++ b/rag_tutorials/vision_rag_agent/README.md @@ -0,0 +1,91 @@ +# Vision RAG with Cohere Embed-4 đŸ–ŧī¸ + +A powerful visual Retrieval-Augmented Generation (RAG) system that utilizes Cohere's state-of-the-art Embed-4 model for multimodal embedding and Google's efficient Gemini 2.5 Flash model for answering questions about images. + +## Features + +- **Multimodal Search**: Leverages Cohere Embed-4 to find the most semantically relevant image for a given text question. +- **Visual Question Answering**: Employs Google Gemini 2.5 Flash to analyze the content of the retrieved image and generate accurate, context-aware answers. +- **Flexible Image Sources**: + - Use pre-loaded sample financial charts and infographics. + - Upload your own custom images (PNG, JPG, JPEG). +- **No OCR Required**: Directly processes complex images like charts, graphs, and infographics without needing separate text extraction steps. +- **Interactive UI**: Built with Streamlit for easy interaction, including image loading, question input, and result display. +- **Session Management**: Remembers loaded/uploaded images within a session. + +## Requirements + +- Python 3.8+ +- Cohere API key +- Google Gemini API key + +## How to Run + +Follow these steps to set up and run the application: + +1. **Clone and Navigate to Directory** : + ```bash + git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git + cd awesome-llm-apps/rag_tutorials/vision_rag_agent + ``` + +2. **Install Dependencies**: + ```bash + pip install -r requirements.txt + ``` + +3. **Set up your API keys**: + - Get a Cohere API key from: [https://dashboard.cohere.com/api-keys](https://dashboard.cohere.com/api-keys) + - Get a Google API key from: [https://aistudio.google.com/app/apikey](https://aistudio.google.com/app/apikey) + +4. **Run the Streamlit app**: + ```bash + streamlit run vision_rag.py + ``` + +5. **Access the Web Interface**: + - Streamlit will provide a local URL (usually `http://localhost:8501`) in your terminal. + - Open this URL in your web browser. + +## How It Works + +The application follows a two-stage RAG process: + +1. **Retrieval**: + - When you load sample images or upload your own, each image is converted to a base64 string. + - Cohere's `embed-v4.0` model (with `input_type="search_document"`) is used to generate a dense vector embedding for each image. + - When you ask a question, the text query is embedded using the same `embed-v4.0` model (with `input_type="search_query"`). + - Cosine similarity is calculated between the question embedding and all image embeddings. + - The image with the highest similarity score is retrieved as the most relevant context. + +2. **Generation**: + - The original text question and the retrieved image are passed as input to the Google `gemini-2.5-flash-preview-04-17` model. + - Gemini analyzes the image content in the context of the question and generates a textual answer. + +## Usage + +1. Enter your Cohere and Google API keys in the sidebar. +2. Load images: + - Click **"Load Sample Images"** to download and process the built-in examples. + - *OR/AND* Use the **"Upload Your Images"** section to upload your own image files. +3. Once images are loaded and processed (embeddings generated), the **"Ask a Question"** section will be enabled. +4. Optionally, expand **"View Loaded Images"** to see thumbnails of all images currently in the session. +5. Type your question about the loaded images into the text input field. +6. Click **"Run Vision RAG"**. +7. View the results: + - The **Retrieved Image** deemed most relevant to your question. + - The **Generated Answer** from Gemini based on the image and question. + +## Use Cases + +- Analyze financial charts and extract key figures or trends. +- Answer specific questions about diagrams, flowcharts, or infographics. +- Extract information from tables or text within screenshots without explicit OCR. +- Build and query visual knowledge bases using natural language. +- Understand the content of various complex visual documents. + +## Note + +- Image processing (embedding) can take time, especially for many or large images. Sample images are cached after the first load. +- Ensure your API keys have the necessary permissions and quotas for the Cohere and Gemini models used. +- The quality of the answer depends on both the relevance of the retrieved image and the capability of the Gemini model to interpret the image based on the question. diff --git a/rag_tutorials/vision_rag_agent/requirements.txt b/rag_tutorials/vision_rag_agent/requirements.txt new file mode 100644 index 0000000..d22b9ee --- /dev/null +++ b/rag_tutorials/vision_rag_agent/requirements.txt @@ -0,0 +1,7 @@ +streamlit>=1.32.0 +cohere>=5.0.0 +google-generativeai>=0.3.0 +Pillow>=10.0.0 +requests>=2.31.0 +numpy>=1.24.0 +tqdm>=4.66.0 diff --git a/rag_tutorials/vision_rag_agent/vision_rag.py b/rag_tutorials/vision_rag_agent/vision_rag.py new file mode 100644 index 0000000..c5ca4b3 --- /dev/null +++ b/rag_tutorials/vision_rag_agent/vision_rag.py @@ -0,0 +1,453 @@ +import requests +import os +import io +import base64 +import PIL +from PIL import Image +import tqdm +import numpy as np +import streamlit as st +import cohere +from google import genai + +# --- Streamlit App Configuration --- +st.set_page_config(layout="wide", page_title="Vision RAG with Cohere Embed-4") +st.title("Vision RAG with Cohere Embed-4 đŸ–ŧī¸") + +# --- API Key Input --- +with st.sidebar: + st.header("🔑 API Keys") + cohere_api_key = st.text_input("Cohere API Key", type="password", key="cohere_key") + google_api_key = st.text_input("Google API Key (Gemini)", type="password", key="google_key") + "[Get a Cohere API key](https://dashboard.cohere.com/api-keys)" + "[Get a Google API key](https://aistudio.google.com/app/apikey)" + + st.markdown("---") + if not cohere_api_key: + st.warning("Please enter your Cohere API key to proceed.") + if not google_api_key: + st.warning("Please enter your Google API key to proceed.") + st.markdown("---") + + +# --- Initialize API Clients --- +co = None +genai_client = None +# Initialize Session State for embeddings and paths +if 'image_paths' not in st.session_state: + st.session_state.image_paths = [] +if 'doc_embeddings' not in st.session_state: + st.session_state.doc_embeddings = None + +if cohere_api_key and google_api_key: + try: + co = cohere.ClientV2(api_key=cohere_api_key) + st.sidebar.success("Cohere Client Initialized!") + except Exception as e: + st.sidebar.error(f"Cohere Initialization Failed: {e}") + + try: + genai_client = genai.Client(api_key=google_api_key) + st.sidebar.success("Gemini Client Initialized!") + except Exception as e: + st.sidebar.error(f"Gemini Initialization Failed: {e}") +else: + st.info("Enter your API keys in the sidebar to start.") + +# Information about the models +with st.expander("â„šī¸ About the models used"): + st.markdown(""" + ### Cohere Embed-4 + + Cohere's Embed-4 is a state-of-the-art multimodal embedding model designed for enterprise search and retrieval. + It enables: + + - **Multimodal search**: Search text and images together seamlessly + - **High accuracy**: State-of-the-art performance for retrieval tasks + - **Efficient embedding**: Process complex images like charts, graphs, and infographics + + The model processes images without requiring complex OCR pre-processing and maintains the connection between visual elements and text. + + ### Google Gemini 2.5 Flash + + Gemini 2.5 Flash is Google's efficient multimodal model that can process text and image inputs to generate high-quality responses. + It's designed for fast inference while maintaining high accuracy, making it ideal for real-time applications like this RAG system. + """) + +# --- Helper functions --- +# Some helper functions to resize images and to convert them to base64 format +max_pixels = 1568*1568 #Max resolution for images + +# Resize too large images +def resize_image(pil_image: PIL.Image.Image) -> None: + """Resizes the image in-place if it exceeds max_pixels.""" + org_width, org_height = pil_image.size + + # Resize image if too large + if org_width * org_height > max_pixels: + scale_factor = (max_pixels / (org_width * org_height)) ** 0.5 + new_width = int(org_width * scale_factor) + new_height = int(org_height * scale_factor) + pil_image.thumbnail((new_width, new_height)) + +# Convert images to a base64 string before sending it to the API +def base64_from_image(img_path: str) -> str: + """Converts an image file to a base64 encoded string.""" + pil_image = PIL.Image.open(img_path) + img_format = pil_image.format if pil_image.format else "PNG" + + resize_image(pil_image) + + with io.BytesIO() as img_buffer: + pil_image.save(img_buffer, format=img_format) + img_buffer.seek(0) + img_data = f"data:image/{img_format.lower()};base64,"+base64.b64encode(img_buffer.read()).decode("utf-8") + + return img_data + +# Convert PIL image to base64 string +def pil_to_base64(pil_image: PIL.Image.Image) -> str: + """Converts a PIL image to a base64 encoded string.""" + if pil_image.format is None: + img_format = "PNG" + else: + img_format = pil_image.format + + resize_image(pil_image) + + with io.BytesIO() as img_buffer: + pil_image.save(img_buffer, format=img_format) + img_buffer.seek(0) + img_data = f"data:image/{img_format.lower()};base64,"+base64.b64encode(img_buffer.read()).decode("utf-8") + + return img_data + +# Compute embedding for an image +@st.cache_data(ttl=3600, show_spinner=False) +def compute_image_embedding(base64_img: str, _cohere_client) -> np.ndarray: + """Computes an embedding for an image using Cohere's Embed-4 model.""" + try: + api_response = _cohere_client.embed( + model="embed-v4.0", + input_type="search_document", + embedding_types=["float"], + images=[base64_img], + ) + + if api_response.embeddings and api_response.embeddings.float: + return np.asarray(api_response.embeddings.float[0]) + else: + st.warning("Could not get embedding. API response might be empty.") + return None + except Exception as e: + st.error(f"Error computing embedding: {e}") + return None + +# Download and embed sample images +@st.cache_data(ttl=3600, show_spinner=False) +def download_and_embed_sample_images(_cohere_client): + """Downloads sample images and computes their embeddings using Cohere's Embed-4 model.""" + # Several images from https://www.appeconomyinsights.com/ + images = { + "tesla.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbef936e6-3efa-43b3-88d7-7ec620cdb33b_2744x1539.png", + "netflix.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F23bd84c9-5b62-4526-b467-3088e27e4193_2744x1539.png", + "nike.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fa5cd33ba-ae1a-42a8-a254-d85e690d9870_2741x1541.png", + "google.png": "https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F395dd3b9-b38e-4d1f-91bc-d37b642ee920_2741x1541.png", + "accenture.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F08b2227c-7dc8-49f7-b3c5-13cab5443ba6_2741x1541.png", + "tecent.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0ec8448c-c4d1-4aab-a8e9-2ddebe0c95fd_2741x1541.png" + } + + # Prepare folders + img_folder = "img" + os.makedirs(img_folder, exist_ok=True) + + img_paths = [] + doc_embeddings = [] + + # Wrap TQDM with st.spinner for better UI integration + with st.spinner("Downloading and embedding sample images..."): + pbar = tqdm.tqdm(images.items(), desc="Processing sample images") + for name, url in pbar: + img_path = os.path.join(img_folder, name) + # Don't re-append if already processed (useful if function called multiple times) + if img_path not in img_paths: + img_paths.append(img_path) + + # Download the image + if not os.path.exists(img_path): + try: + response = requests.get(url) + response.raise_for_status() + with open(img_path, "wb") as fOut: + fOut.write(response.content) + except requests.exceptions.RequestException as e: + st.error(f"Failed to download {name}: {e}") + # Optionally remove the path if download failed + img_paths.pop() + continue # Skip if download fails + + # Get embedding for the image if it exists and we haven't computed one yet + # Find index corresponding to this path + current_index = -1 + try: + current_index = img_paths.index(img_path) + except ValueError: + continue # Should not happen if append logic is correct + + # Check if embedding already exists for this index + if current_index >= len(doc_embeddings): + try: + # Ensure file exists before trying to embed + if os.path.exists(img_path): + base64_img = base64_from_image(img_path) + emb = compute_image_embedding(base64_img, _cohere_client=_cohere_client) + if emb is not None: + # Placeholder to ensure list length matches paths before vstack + while len(doc_embeddings) < current_index: + doc_embeddings.append(None) # Append placeholder if needed + doc_embeddings.append(emb) + else: + # If file doesn't exist (maybe failed download), add placeholder + while len(doc_embeddings) < current_index: + doc_embeddings.append(None) + doc_embeddings.append(None) + except Exception as e: + st.error(f"Failed to embed {name}: {e}") + # Add placeholder on error + while len(doc_embeddings) < current_index: + doc_embeddings.append(None) + doc_embeddings.append(None) + + # Filter out None embeddings and corresponding paths before stacking + filtered_paths = [path for i, path in enumerate(img_paths) if i < len(doc_embeddings) and doc_embeddings[i] is not None] + filtered_embeddings = [emb for emb in doc_embeddings if emb is not None] + + if filtered_embeddings: + doc_embeddings_array = np.vstack(filtered_embeddings) + return filtered_paths, doc_embeddings_array + + return [], None + +# Search function +def search(question: str, co_client: cohere.Client, embeddings: np.ndarray, image_paths: list[str], max_img_size: int = 800) -> str | None: + """Finds the most relevant image path for a given question.""" + if not co_client or embeddings is None or embeddings.size == 0 or not image_paths: + st.warning("Search prerequisites not met (client, embeddings, or paths missing/empty).") + return None + if embeddings.shape[0] != len(image_paths): + st.error(f"Mismatch between embeddings count ({embeddings.shape[0]}) and image paths count ({len(image_paths)}). Cannot perform search.") + return None + + try: + # Compute the embedding for the query + api_response = co_client.embed( + model="embed-v4.0", + input_type="search_query", + embedding_types=["float"], + texts=[question], + ) + + if not api_response.embeddings or not api_response.embeddings.float: + st.error("Failed to get query embedding.") + return None + + query_emb = np.asarray(api_response.embeddings.float[0]) + + # Ensure query embedding has the correct shape for dot product + if query_emb.shape[0] != embeddings.shape[1]: + st.error(f"Query embedding dimension ({query_emb.shape[0]}) does not match document embedding dimension ({embeddings.shape[1]}).") + return None + + # Compute cosine similarities + cos_sim_scores = np.dot(query_emb, embeddings.T) + + # Get the most relevant image + top_idx = np.argmax(cos_sim_scores) + hit_img_path = image_paths[top_idx] + print(f"Question: {question}") # Keep for debugging + print(f"Most relevant image: {hit_img_path}") # Keep for debugging + + return hit_img_path + except Exception as e: + st.error(f"Error during search: {e}") + return None + +# Answer function +def answer(question: str, img_path: str, gemini_client) -> str: + """Answers the question based on the provided image using Gemini.""" + if not gemini_client or not img_path or not os.path.exists(img_path): + missing = [] + if not gemini_client: missing.append("Gemini client") + if not img_path: missing.append("Image path") + elif not os.path.exists(img_path): missing.append(f"Image file at {img_path}") + return f"Answering prerequisites not met ({', '.join(missing)} missing or invalid)." + try: + img = PIL.Image.open(img_path) + prompt = [f"""Answer the question based on the following image. Be as elaborate as possible giving extra relevant information. +Don't use markdown formatting in the response. +Please provide enough context for your answer. + +Question: {question}""", img] + + response = gemini_client.models.generate_content( + model="gemini-2.5-flash-preview-04-17", + contents=prompt + ) + + llm_answer = response.text + print("LLM Answer:", llm_answer) # Keep for debugging + return llm_answer + except Exception as e: + st.error(f"Error during answer generation: {e}") + return f"Failed to generate answer: {e}" + +# --- Main UI Setup --- +st.subheader("📊 Load Sample Images") +if cohere_api_key and co: + # If button clicked, load sample images into session state + if st.button("Load Sample Images", key="load_sample_button"): + sample_img_paths, sample_doc_embeddings = download_and_embed_sample_images(_cohere_client=co) + if sample_img_paths and sample_doc_embeddings is not None: + # Append sample images to session state (avoid duplicates if clicked again) + current_paths = set(st.session_state.image_paths) + new_paths = [p for p in sample_img_paths if p not in current_paths] + + if new_paths: + new_indices = [i for i, p in enumerate(sample_img_paths) if p in new_paths] + st.session_state.image_paths.extend(new_paths) + new_embeddings_to_add = sample_doc_embeddings[[idx for idx, p in enumerate(sample_img_paths) if p in new_paths]] + + if st.session_state.doc_embeddings is None or st.session_state.doc_embeddings.size == 0: + st.session_state.doc_embeddings = new_embeddings_to_add + else: + st.session_state.doc_embeddings = np.vstack((st.session_state.doc_embeddings, new_embeddings_to_add)) + st.success(f"Loaded {len(new_paths)} sample images.") + else: + st.info("Sample images already loaded.") + else: + st.error("Failed to load sample images. Check console for errors.") +else: + st.warning("Enter API keys to enable loading sample images.") + +st.markdown("--- ") +# --- File Uploader (Main UI) --- +st.subheader("📤 Upload Your Images") +st.info("Or, upload your own images. The RAG process will search across all loaded sample images and uploaded images.") + +# File uploader +uploaded_files = st.file_uploader("Upload images", type=["png", "jpg", "jpeg"], + accept_multiple_files=True, key="image_uploader", + label_visibility="collapsed") + +# Process uploaded images +if uploaded_files and co: + st.write(f"Processing {len(uploaded_files)} uploaded images...") + progress_bar = st.progress(0) + + # Create a temporary directory for uploaded images + upload_folder = "uploaded_img" + os.makedirs(upload_folder, exist_ok=True) + + newly_uploaded_paths = [] + newly_uploaded_embeddings = [] + + for i, uploaded_file in enumerate(uploaded_files): + # Check if already processed this session (simple name check) + img_path = os.path.join(upload_folder, uploaded_file.name) + if img_path not in st.session_state.image_paths: + try: + # Save the uploaded file + with open(img_path, "wb") as f: + f.write(uploaded_file.getbuffer()) + + # Get embedding + base64_img = base64_from_image(img_path) + emb = compute_image_embedding(base64_img, _cohere_client=co) + + if emb is not None: + newly_uploaded_paths.append(img_path) + newly_uploaded_embeddings.append(emb) + + except Exception as e: + st.error(f"Error processing {uploaded_file.name}: {e}") + # Update progress regardless of processing status for user feedback + progress_bar.progress((i + 1) / len(uploaded_files)) + + # Add newly processed files to session state + if newly_uploaded_paths: + st.session_state.image_paths.extend(newly_uploaded_paths) + if newly_uploaded_embeddings: + new_embeddings_array = np.vstack(newly_uploaded_embeddings) + if st.session_state.doc_embeddings is None or st.session_state.doc_embeddings.size == 0: + st.session_state.doc_embeddings = new_embeddings_array + else: + st.session_state.doc_embeddings = np.vstack((st.session_state.doc_embeddings, new_embeddings_array)) + st.success(f"Successfully processed and added {len(newly_uploaded_paths)} new images.") + else: + st.warning("Failed to generate embeddings for newly uploaded images.") + elif uploaded_files: # If files were selected but none were new + st.info("Selected images already seem to be processed.") + +# --- Vision RAG Section (Main UI) --- +st.markdown("---") +st.subheader("❓ Ask a Question") + +if not st.session_state.image_paths: + st.warning("Please load sample images or upload your own images first.") +else: + st.info(f"Ready to answer questions about {len(st.session_state.image_paths)} images.") + + # Display thumbnails of all loaded images (optional) + with st.expander("View Loaded Images", expanded=False): + if st.session_state.image_paths: + num_images_to_show = len(st.session_state.image_paths) + cols = st.columns(5) # Show 5 thumbnails per row + for i in range(num_images_to_show): + with cols[i % 5]: + # Add try-except for missing files during display + try: + st.image(st.session_state.image_paths[i], width=100, caption=os.path.basename(st.session_state.image_paths[i])) + except FileNotFoundError: + st.error(f"Missing: {os.path.basename(st.session_state.image_paths[i])}") + else: + st.write("No images loaded yet.") + +question = st.text_input("Ask a question about the loaded images:", + key="main_question_input", + placeholder="E.g., What is Nike's net profit?", + disabled=not st.session_state.image_paths) + +run_button = st.button("Run Vision RAG", key="main_run_button", + disabled=not (cohere_api_key and google_api_key and question and st.session_state.image_paths and st.session_state.doc_embeddings is not None and st.session_state.doc_embeddings.size > 0)) + +# Output Area +st.markdown("### Results") +retrieved_image_placeholder = st.empty() +answer_placeholder = st.empty() + +# Run search and answer logic +if run_button: + if co and genai_client and st.session_state.doc_embeddings is not None and len(st.session_state.doc_embeddings) > 0: + with st.spinner("Finding relevant image..."): + # Ensure embeddings and paths match before search + if len(st.session_state.image_paths) != st.session_state.doc_embeddings.shape[0]: + st.error("Error: Mismatch between number of images and embeddings. Cannot proceed.") + else: + top_image_path = search(question, co, st.session_state.doc_embeddings, st.session_state.image_paths) + + if top_image_path: + retrieved_image_placeholder.image(top_image_path, caption=f"Retrieved image for: '{question}'", use_container_width=True) + + with st.spinner("Generating answer..."): + final_answer = answer(question, top_image_path, genai_client) + answer_placeholder.markdown(f"**Answer:**\n{final_answer}") + else: + retrieved_image_placeholder.warning("Could not find a relevant image for your question.") + answer_placeholder.text("") # Clear answer placeholder + else: + # This case should ideally be prevented by the disabled state of the button + st.error("Cannot run RAG. Check API clients and ensure images are loaded with embeddings.") + +# Footer +st.markdown("---") +st.caption("Vision RAG with Cohere Embed-4 | Built with Streamlit, Cohere Embed-4, and Google Gemini 2.5 Flash") \ No newline at end of file