Introduction
In the age of rapid AI advancement, software engineers are frequently tasked with building sophisticated chatbots. This article provides a practical guide to developing a real-time, response-streaming chatbot from scratch using transformer models of both the autoencoding (BERT family) and autoregressive (GPT family) varieties. We’ll build a chatbot capable of contextual conversations, processing user text and image/PDF uploads for enhanced understanding, and we’ll use Python (aiohttp) and React (Next.js) to deliver a seamless, interactive experience. Let’s dive in!
Prerequisites
No prior expertise is strictly necessary, but some foundation in AI, software engineering, Python, and frontend development with a framework like React, Svelte, Vue, or Angular will be beneficial. If you’re new to aiohttp or Svelte, we highly suggest exploring this previous article series as a helpful primer. Don’t worry if you’re not an expert (I ain’t).
To set up your project, first create a root folder `genai-chatbot` and then navigate into the `backend` subdirectory:
```bash
mkdir genai-chatbot && cd genai-chatbot && mkdir backend && cd backend
```
In the `backend` directory, create a Python virtual environment and install dependencies from `requirements.txt`:
```bash
python3 -m venv virtualenv && source virtualenv/bin/activate && pip install -r requirements.txt
```
Ensure you create a `requirements.txt` file in the `backend` directory with the following content:
```
aiohttp
aiodns
transformers
torch
torchvision
pymupdf
pytesseract
```
For the frontend, navigate back to the root directory (`genai-chatbot`) and use `create-next-app` to generate a React frontend with TypeScript:
```bash
cd ..
npx create-next-app@latest react-frontend --typescript --eslint --app
```
Follow the prompts from `create-next-app`.
**Tailwind CSS setup recommendation:** `create-next-app` might include an older version of Tailwind CSS by default. For the best experience, skip the default Tailwind CSS setup during the `create-next-app` configuration and instead follow the official Tailwind CSS Next.js installation guide to install the latest version, ensuring you add Tailwind CSS and related packages as `devDependencies`.
Source code
Sirneij / genai-chatbot
A streaming chatbot powered by Large Language Models
GenAI ChatBot
A full-stack chat application featuring autoregressive and masked language models with streaming responses and context-aware interactions.
Overview
- Frontend: Next.js-based chat interface with real-time WebSocket streaming
- Backend: Python (aiohttp) server handling LLMs and file processing
Quick Start
- Clone the repository and enter the directory:
```bash
git clone https://github.com/Sirneij/genai-chatbot.git
cd genai-chatbot
```
- Start the backend server (in a new terminal):
```bash
cd backend
python -m venv virtualenv
source virtualenv/bin/activate  # On Windows, use: virtualenv\Scripts\activate
pip install -r requirements.txt
adev runserver main.py  # For development with auto-reload
# Or use: python main.py  # For production
```
- Start the frontend app (in another terminal):
```bash
cd react-frontend
npm install
npm run dev
```
The application will be available at http://localhost:3000
Features
- Real-time streaming responses
- Support for both autoregressive and masked LLMs
- File upload support (PDFs/Images)
- Rich text rendering:
  - Markdown support
  - Code syntax highlighting
  - KaTeX math expressions
  - Unicode emoji support
- Dark mode and responsive design
Project Structure
Implementation
Step 1: Theory – Autoregressive vs Autoencoding models
Two primary approaches emerge within the transformer model architecture: autoregressive and autoencoding. Autoregressive models, like those using the decoder component of a transformer, predict the next word (token) in a sequence based solely on the preceding words (a strictly left-to-right approach). This sequential generation method is ideal for tasks like text generation and response streaming. In contrast, autoencoding models (also known as masked language models) leverage the encoder part of the transformer. They operate bidirectionally, considering both preceding and subsequent words in a sentence to understand the context and produce representations. Refer to Hugging Face’s model summary for a detailed comparison. Pretraining methodology is a key differentiator: autoregressive models are pretrained to predict the next token, making them naturally suited for text generation, while autoencoding models excel at contextual understanding, making them strong for tasks like question answering based on documents.
In this project, we will practically explore the distinct text generation capabilities of both model types in real time. We will use `microsoft/Phi-3-mini-4k-instruct` (autoregressive) and `deepset/roberta-base-squad2` (autoencoding) for demonstration, acknowledging resource limitations and focusing on showcasing the core differences.
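To make the contrast concrete, here is a minimal sketch using Hugging Face pipelines. It is illustrative only: `gpt2` stands in for the autoregressive side to keep the download small, while the extractive side uses the same `deepset/roberta-base-squad2` checkpoint as the project.

```python
from transformers import pipeline

# Autoregressive: continues the prompt left to right, one token at a time.
generator = pipeline('text-generation', model='gpt2')
print(generator('The transformer architecture', max_new_tokens=20)[0]['generated_text'])

# Autoencoding: reads the whole context bidirectionally and extracts an answer span.
qa = pipeline('question-answering', model='deepset/roberta-base-squad2')
result = qa(
    question='What does the encoder do?',
    context='The encoder reads the entire sentence bidirectionally to build contextual representations.',
)
print(result['answer'])
```

Note how the first pipeline invents new text while the second can only point at a span inside the context you give it; that asymmetry drives the design decisions below.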
Step 2: Python asynchronous backend with aiohttp
We’ll start with the project structure from Building and Rigorously Testing a WebSocket and HTTP Server. Populate `src/app/__init__.py` with the following code, noting the key adaptations for our chatbot:
```python
import json
from weakref import WeakSet

from aiohttp import WSCloseCode, web
from aiohttp.web import Request, Response, WebSocketResponse

from src.utils.auto_chat_engine import (
    cleanup_auto_model,
    gpt_question_and_answer,
    prepare_auto_tokenizer_and_model,
)
from src.utils.chat_engine import (
    prepare_qa_tokenizer_and_model,
    squad_question_answering,
)
from src.utils.extract import extract_text_from_file
from src.utils.settings import base_settings

WEBSOCKETS = web.AppKey('websockets', WeakSet[WebSocketResponse])


async def start_background_tasks(app: web.Application) -> None:
    """Initialize application background tasks."""
    app[WEBSOCKETS] = WeakSet()
    await prepare_auto_tokenizer_and_model(base_settings.MODEL_NAME)
    await prepare_qa_tokenizer_and_model(base_settings.QA_MODEL_NAME)


async def cleanup_app(app: web.Application) -> None:
    """Cleanup WebSocket connections on shutdown."""
    # Cleanup models
    await cleanup_auto_model()
    # Close all WebSocket connections
    for websocket in set(app[WEBSOCKETS]):  # type: ignore
        await websocket.close(code=WSCloseCode.GOING_AWAY, message=b'Server shutdown')


async def extract_text(request: Request) -> Response:
    """Extract text from PDF and image files."""
    data = await request.post()
    files = data.getall('file', [])  # default avoids a KeyError when no file field is sent
    if not files:
        return web.json_response({'error': 'No files provided'}, status=400)
    extracted_text = []
    for file in files:
        if file.content_type not in ['application/pdf', 'image/jpeg', 'image/png']:
            return web.json_response({'error': 'Invalid file type'}, status=400)
        text = await extract_text_from_file(file.file, file.content_type)
        extracted_text.append(text)
    base_settings.context = '\n'.join(extracted_text)
    return web.json_response({'success': 'Text extracted successfully'})


async def chat_handler(request: Request) -> Response:
    """Handle WebSocket connections."""
    ws = WebSocketResponse()
    await ws.prepare(request)
    request.app[WEBSOCKETS].add(ws)
    async for msg in ws:
        if msg.type == web.WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
                question_type = data.get('type')
                question = data.get('question', '').strip()
                if not question:
                    await ws.send_str('Error: No question provided.')
                    continue
                if question_type == 'auto':
                    # Stream response token by token.
                    async for token in gpt_question_and_answer(question):
                        await ws.send_json({'answer': token})
                elif question_type == 'masked':
                    # Use squad question answering (non-streamed).
                    answer = await squad_question_answering(question)
                    await ws.send_json({'answer': answer})
                else:
                    await ws.send_str('Error: Unknown question type.')
            except Exception as e:
                await ws.send_str(f'Error processing message: {str(e)}')
        elif msg.type == web.WSMsgType.ERROR:
            break
    # discard() tolerates a socket that has already been dropped
    request.app[WEBSOCKETS].discard(ws)
    return ws


def init_app() -> web.Application:
    """Initialize the application."""
    app = web.Application()
    # Add routes
    app.router.add_post('/api/extract', extract_text)
    app.router.add_get('/chat', chat_handler)
    # Add startup/cleanup handlers
    app.on_startup.append(start_background_tasks)
    app.on_shutdown.append(cleanup_app)
    return app
```
This `__init__.py` file builds upon the structure from the recommended series. The key differences for our chatbot are:

- `start_background_tasks` (on startup) loads the AI models, and `cleanup_app` (on shutdown) unloads them, improving resource handling.
- The `extract_text` handler accepts and processes multiple uploaded files to create a combined context.
- We use direct `aiohttp` WebSocket APIs in `chat_handler`, supporting `auto` and `masked` message types for the autoregressive and masked language models respectively; the sketch after this list shows the message shapes a client sends.
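For reference, `chat_handler` expects JSON payloads shaped like `{"type": "auto" | "masked", "question": "..."}` and replies with `{"answer": ...}` frames. Here is a minimal client sketch; the `localhost:8080` address is an assumption, so match it to wherever you run `main.py`:

```python
import asyncio
import json

import aiohttp


async def ask(question: str, question_type: str = 'auto') -> None:
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect('http://localhost:8080/chat') as ws:
            await ws.send_json({'type': question_type, 'question': question})
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    break
                if msg.data.startswith('Error'):  # errors arrive as plain strings
                    print(msg.data)
                    break
                token = json.loads(msg.data)['answer']
                if token == '[END]':  # streaming sentinel emitted by the backend
                    break
                print(token, end='', flush=True)
                if question_type == 'masked':  # masked answers arrive in one frame
                    break


asyncio.run(ask('What is a transformer?'))
```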
Next, let’s delve into the files within the `src/utils` directory, starting with `settings.py`:
```python
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)


class Settings:
    logger = logger
    SYSTEM_PROMPT = 'You are an expert Q&A assistant. Provide accurate answers based on the context.'
    context = None
    MODEL_NAME = 'microsoft/Phi-3-mini-4k-instruct'
    QA_MODEL_NAME = 'deepset/roberta-base-squad2'


base_settings = Settings()
```
The `settings.py` file primarily houses our `logger` configuration and defines several important constants used in the application, including the system prompt, the default language model (`MODEL_NAME`), and the question answering model (`QA_MODEL_NAME`).
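Because `base_settings` is a plain module-level singleton, any module can read or mutate it; this is exactly how `extract_text` shares uploaded-document context with the chat engines. For example:

```python
from src.utils.settings import base_settings

base_settings.logger.info('Default model: %s', base_settings.MODEL_NAME)
base_settings.context = 'Text pulled from an uploaded PDF'  # later read by gpt_question_and_answer
```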
Let’s examine `base.py` next:
```python
import torch


def get_device() -> tuple[torch.device, str]:
    if torch.cuda.is_available():
        return torch.device('cuda'), 'CUDA (NVIDIA GPU)'
    elif torch.backends.mps.is_available():
        return torch.device('mps'), 'MPS (Apple Metal)'
    else:
        return torch.device('cpu'), 'CPU'


async def get_stopping_strings(type: str, question: str, context: str | None = None) -> tuple[str, list[str]]:
    """Get stopping strings based on the question type."""
    ctx = f'\nUse this as the base Context: {context}\n' if context else ''
    STOPPING_STRINGS = {
        'exNormal': {
            'prompt': (
                f'Answer the question, succinctly, using markdown formatting and katex for math.{ctx}. '
                'Do not use unicodes, if you must, use \\ such as \\u1F60A for smile.\n'
                'After providing your complete answer, you must conclude your response with \n§\n '
                'as the final line, with no text following it. '
                f'\n\nQuestion: {question}\n\nAnswer:'
            ),
            'end': ['§', '\n§', '§\n', ' §', '\n§\n'],
        },
    }
    return STOPPING_STRINGS[type]['prompt'], STOPPING_STRINGS[type]['end']
```
`get_device` detects the available hardware and returns the appropriate PyTorch device (CUDA GPU, Apple Metal (MPS), or CPU) along with a descriptive string. `get_stopping_strings` plays a crucial role in prompt construction for the autoregressive model: it dynamically generates a prompt and a list of stopping strings based on the question type and optional context. The prompt instructs the model to provide a concise answer using Markdown formatting and KaTeX for mathematical expressions. A key challenge with autoregressive models is their tendency to generate repetitive text. To mitigate this, the prompt explicitly instructs the model to conclude its response with the “§” symbol, an uncommon character chosen to signal the end of the answer; the application then programmatically halts text generation upon encountering it. While using the model’s native end-of-sequence token would be ideal, I had no luck with it.
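As a quick sanity check of the helper (a throwaway snippet, not part of the app):

```python
import asyncio

from src.utils.base import get_stopping_strings

prompt, stops = asyncio.run(get_stopping_strings('exNormal', 'What is 2 + 2?'))
print(stops)                    # ['§', '\n§', '§\n', ' §', '\n§\n']
print(prompt.splitlines()[-1])  # 'Answer:' -- the model is expected to continue from here
```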
The `auto_chat_engine.py` file is the central component responsible for generating text using the configured language model. It orchestrates the model’s loading, manages the text generation process, and handles device-specific optimizations.
```python
import asyncio

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast

from src.utils.base import get_device, get_stopping_strings
from src.utils.settings import base_settings

TOKENIZER, MODEL = None, None


async def prepare_auto_tokenizer_and_model(model_name: str) -> None:
    """
    Prepare and load the tokenizer and model with hardware-specific optimizations.

    Args:
        model_name (str): The name of the model to load.
    """
    global TOKENIZER, MODEL
    if TOKENIZER is not None and MODEL is not None:
        return
    device, _ = get_device()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, padding_side='left', truncation_side='left', use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    TOKENIZER = tokenizer

    # Model-specific settings
    load_kwargs = {}
    if 'phi-3' in model_name.lower():
        load_kwargs['trust_remote_code'] = True
        load_kwargs['attn_implementation'] = 'eager'  # No CUDA, so use eager mode
    elif 'starcoder' in model_name.lower():
        load_kwargs['trust_remote_code'] = True

    # Load model based on device
    if device.type == 'cpu':
        # Use float32 for CPU and apply dynamic quantization
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            **load_kwargs,
        )
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    elif device.type == 'mps':
        # Use bfloat16 for MPS
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            **load_kwargs,
        ).to(device)
    else:
        # Use float16 for CUDA GPUs
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            **load_kwargs,
        ).to(device)
    model.eval()
    MODEL = model


async def cleanup_auto_model() -> None:
    """Clear model from memory, important for Apple Silicon."""
    global MODEL, TOKENIZER
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    del MODEL
    del TOKENIZER


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and nucleus (top-p) filtering."""
    if top_k > 0:
        values, _ = torch.topk(logits, top_k)
        kth_value = values[..., -1, None]
        logits = torch.where(logits < kth_value, torch.full_like(logits, filter_value), logits)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


async def stream_chat_response(
    prompt: str,
    tokenizer: PreTrainedTokenizerFast,
    model,
    stopping_strings: list[str],
    max_new_tokens: int = 100,
    temperature: float = 0.5,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.5,
    repetition_window: int = 10,
):
    device = next(model.parameters()).device
    eos_token_id = tokenizer.eos_token_id
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    full_generated_ids = input_ids
    previous_generated_text = ''
    generated_tokens = []
    for _ in range(max_new_tokens):
        with torch.no_grad():
            if device.type == 'cuda':
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(full_generated_ids, attention_mask=attention_mask)
            elif device.type == 'mps':
                with torch.autocast(device_type='mps', dtype=torch.bfloat16):
                    outputs = model(full_generated_ids, attention_mask=attention_mask)
            else:  # CPU
                outputs = model(full_generated_ids, attention_mask=attention_mask)
        next_token_logits = outputs.logits[:, -1, :]

        # Apply repetition penalty to recently generated tokens
        if repetition_penalty != 1.0:
            for token in set(generated_tokens[-20:]):
                next_token_logits[:, token] /= repetition_penalty

        scaled_logits = next_token_logits / temperature
        filtered_logits = top_k_top_p_filtering(scaled_logits[0], top_k=top_k, top_p=top_p)
        probabilities = torch.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1).unsqueeze(0)
        current_token = next_token.item()

        full_generated_ids = torch.cat([full_generated_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
        generated_tokens.append(current_token)

        if current_token == eos_token_id:
            break

        current_generated_text = tokenizer.decode(
            full_generated_ids[0][input_ids.shape[1]:].cpu(), skip_special_tokens=True
        )

        # Check stopping conditions
        if any(stop_str in current_generated_text for stop_str in stopping_strings):
            break
        if (
            len(generated_tokens) >= 2 * repetition_window
            and generated_tokens[-repetition_window:]
            == generated_tokens[-2 * repetition_window : -repetition_window]
        ):
            break

        diff = current_generated_text[len(previous_generated_text):]
        if diff:
            yield diff
        previous_generated_text = current_generated_text

        if device.type == 'mps':
            await asyncio.sleep(0.001)  # Yield event loop for MPS
        elif device.type == 'cpu':
            await asyncio.sleep(0.005)  # Reduce CPU load
        else:
            await asyncio.sleep(0)

    # Signal the end of the stream
    yield '[END]'


async def gpt_question_and_answer(question: str):
    """Stream an answer with repetition penalty and better stopping checks."""
    prompt, stopping_strings = await get_stopping_strings('exNormal', question, context=base_settings.context)
    async for chunk in stream_chat_response(
        prompt,
        TOKENIZER,
        MODEL,
        stopping_strings=stopping_strings,
        max_new_tokens=2000,  # Increase this for longer outputs
        top_p=0.95,
    ):
        yield chunk
```
<span>model</span><span>.</span><span>eval</span><span>()</span> <span>MODEL</span> <span>=</span> <span>model</span> <span>async</span> <span>def</span> <span>cleanup_auto_model</span><span>()</span> <span>-></span> <span>None</span><span>:</span> <span>"""</span><span>Clear model from memory, important for Apple Silicon</span><span>"""</span> <span>global</span> <span>MODEL</span><span>,</span> <span>TOKENIZER</span> <span>if</span> <span>torch</span><span>.</span><span>backends</span><span>.</span><span>mps</span><span>.</span><span>is_available</span><span>():</span> <span>torch</span><span>.</span><span>mps</span><span>.</span><span>empty_cache</span><span>()</span> <span>del</span> <span>MODEL</span> <span>del</span> <span>TOKENIZER</span> <span>def</span> <span>top_k_top_p_filtering</span><span>(</span><span>logits</span><span>,</span> <span>top_k</span><span>=</span><span>0</span><span>,</span> <span>top_p</span><span>=</span><span>1.0</span><span>,</span> <span>filter_value</span><span>=-</span><span>float</span><span>(</span><span>"</span><span>Inf</span><span>"</span><span>)):</span> <span>"""</span><span> Filter a distribution of logits using top-k and nucleus (top-p) filtering. </span><span>"""</span> <span>if</span> <span>top_k</span> <span>></span> <span>0</span><span>:</span> <span>values</span><span>,</span> <span>_</span> <span>=</span> <span>torch</span><span>.</span><span>topk</span><span>(</span><span>logits</span><span>,</span> <span>top_k</span><span>)</span> <span>kth_value</span> <span>=</span> <span>values</span><span>[...,</span> <span>-</span><span>1</span><span>,</span> <span>None</span><span>]</span> <span>logits</span> <span>=</span> <span>torch</span><span>.</span><span>where</span><span>(</span><span>logits</span> <span><</span> <span>kth_value</span><span>,</span> <span>torch</span><span>.</span><span>full_like</span><span>(</span><span>logits</span><span>,</span> <span>filter_value</span><span>),</span> <span>logits</span><span>)</span> <span>if</span> <span>top_p</span> <span><</span> <span>1.0</span><span>:</span> <span>sorted_logits</span><span>,</span> <span>sorted_indices</span> <span>=</span> <span>torch</span><span>.</span><span>sort</span><span>(</span><span>logits</span><span>,</span> <span>descending</span><span>=</span><span>True</span><span>)</span> <span>cumulative_probs</span> <span>=</span> <span>torch</span><span>.</span><span>cumsum</span><span>(</span><span>torch</span><span>.</span><span>softmax</span><span>(</span><span>sorted_logits</span><span>,</span> <span>dim</span><span>=-</span><span>1</span><span>),</span> <span>dim</span><span>=-</span><span>1</span><span>)</span> <span>sorted_indices_to_remove</span> <span>=</span> <span>cumulative_probs</span> <span>></span> <span>top_p</span> <span>sorted_indices_to_remove</span><span>[...,</span> <span>1</span><span>:]</span> <span>=</span> <span>sorted_indices_to_remove</span><span>[...,</span> <span>:</span><span>-</span><span>1</span><span>].</span><span>clone</span><span>()</span> <span>sorted_indices_to_remove</span><span>[...,</span> <span>0</span><span>]</span> <span>=</span> <span>0</span> <span>indices_to_remove</span> <span>=</span> <span>sorted_indices</span><span>[</span><span>sorted_indices_to_remove</span><span>]</span> <span>logits</span><span>[</span><span>indices_to_remove</span><span>]</span> <span>=</span> <span>filter_value</span> <span>return</span> <span>logits</span> <span>async</span> <span>def</span> <span>stream_chat_response</span><span>(</span> 
<span>prompt</span><span>:</span> <span>str</span><span>,</span> <span>tokenizer</span><span>:</span> <span>PreTrainedTokenizerFast</span><span>,</span> <span>model</span><span>,</span> <span>stopping_strings</span><span>:</span> <span>list</span><span>[</span><span>str</span><span>],</span> <span>max_new_tokens</span><span>:</span> <span>int</span> <span>=</span> <span>100</span><span>,</span> <span>temperature</span><span>:</span> <span>float</span> <span>=</span> <span>0.5</span><span>,</span> <span>top_k</span><span>:</span> <span>int</span> <span>=</span> <span>50</span><span>,</span> <span>top_p</span><span>:</span> <span>float</span> <span>=</span> <span>0.9</span><span>,</span> <span>repetition_penalty</span><span>:</span> <span>float</span> <span>=</span> <span>1.5</span><span>,</span> <span>repetition_window</span><span>:</span> <span>int</span> <span>=</span> <span>10</span><span>,</span> <span>):</span> <span>device</span> <span>=</span> <span>next</span><span>(</span><span>model</span><span>.</span><span>parameters</span><span>()).</span><span>device</span> <span>eos_token_id</span> <span>=</span> <span>tokenizer</span><span>.</span><span>eos_token_id</span> <span>inputs</span> <span>=</span> <span>tokenizer</span><span>(</span><span>prompt</span><span>,</span> <span>return_tensors</span><span>=</span><span>'</span><span>pt</span><span>'</span><span>)</span> <span>input_ids</span> <span>=</span> <span>inputs</span><span>[</span><span>'</span><span>input_ids</span><span>'</span><span>].</span><span>to</span><span>(</span><span>device</span><span>)</span> <span>attention_mask</span> <span>=</span> <span>inputs</span><span>[</span><span>'</span><span>attention_mask</span><span>'</span><span>].</span><span>to</span><span>(</span><span>device</span><span>)</span> <span>full_generated_ids</span> <span>=</span> <span>input_ids</span> <span>previous_generated_text</span> <span>=</span> <span>''</span> <span>generated_tokens</span> <span>=</span> <span>[]</span> <span>for</span> <span>_</span> <span>in</span> <span>range</span><span>(</span><span>max_new_tokens</span><span>):</span> <span>with</span> <span>torch</span><span>.</span><span>no_grad</span><span>():</span> <span>if</span> <span>device</span><span>.</span><span>type</span> <span>==</span> <span>'</span><span>cuda</span><span>'</span><span>:</span> <span>with</span> <span>torch</span><span>.</span><span>autocast</span><span>(</span><span>device_type</span><span>=</span><span>'</span><span>cuda</span><span>'</span><span>,</span> <span>dtype</span><span>=</span><span>torch</span><span>.</span><span>float16</span><span>):</span> <span>outputs</span> <span>=</span> <span>model</span><span>(</span><span>full_generated_ids</span><span>,</span> <span>attention_mask</span><span>=</span><span>attention_mask</span><span>)</span> <span>elif</span> <span>device</span><span>.</span><span>type</span> <span>==</span> <span>'</span><span>mps</span><span>'</span><span>:</span> <span>with</span> <span>torch</span><span>.</span><span>autocast</span><span>(</span><span>device_type</span><span>=</span><span>'</span><span>mps</span><span>'</span><span>,</span> <span>dtype</span><span>=</span><span>torch</span><span>.</span><span>bfloat16</span><span>):</span> <span>outputs</span> <span>=</span> <span>model</span><span>(</span><span>full_generated_ids</span><span>,</span> <span>attention_mask</span><span>=</span><span>attention_mask</span><span>)</span> <span>else</span><span>:</span> <span># CPU </span> <span>outputs</span> <span>=</span> 
<span>model</span><span>(</span><span>full_generated_ids</span><span>,</span> <span>attention_mask</span><span>=</span><span>attention_mask</span><span>)</span> <span>next_token_logits</span> <span>=</span> <span>outputs</span><span>.</span><span>logits</span><span>[:,</span> <span>-</span><span>1</span><span>,</span> <span>:]</span> <span># Apply repetition penalty </span> <span>if</span> <span>repetition_penalty</span> <span>!=</span> <span>1.0</span><span>:</span> <span>for</span> <span>token</span> <span>in</span> <span>set</span><span>(</span><span>generated_tokens</span><span>[</span><span>-</span><span>20</span><span>:]):</span> <span>next_token_logits</span><span>[:,</span> <span>token</span><span>]</span> <span>/=</span> <span>repetition_penalty</span> <span>scaled_logits</span> <span>=</span> <span>next_token_logits</span> <span>/</span> <span>temperature</span> <span>filtered_logits</span> <span>=</span> <span>top_k_top_p_filtering</span><span>(</span><span>scaled_logits</span><span>[</span><span>0</span><span>],</span> <span>top_k</span><span>=</span><span>top_k</span><span>,</span> <span>top_p</span><span>=</span><span>top_p</span><span>)</span> <span>probabilities</span> <span>=</span> <span>torch</span><span>.</span><span>softmax</span><span>(</span><span>filtered_logits</span><span>,</span> <span>dim</span><span>=-</span><span>1</span><span>)</span> <span>next_token</span> <span>=</span> <span>torch</span><span>.</span><span>multinomial</span><span>(</span><span>probabilities</span><span>,</span> <span>num_samples</span><span>=</span><span>1</span><span>).</span><span>unsqueeze</span><span>(</span><span>0</span><span>)</span> <span>current_token</span> <span>=</span> <span>next_token</span><span>.</span><span>item</span><span>()</span> <span>full_generated_ids</span> <span>=</span> <span>torch</span><span>.</span><span>cat</span><span>([</span><span>full_generated_ids</span><span>,</span> <span>next_token</span><span>],</span> <span>dim</span><span>=-</span><span>1</span><span>)</span> <span>attention_mask</span> <span>=</span> <span>torch</span><span>.</span><span>cat</span><span>([</span><span>attention_mask</span><span>,</span> <span>torch</span><span>.</span><span>ones_like</span><span>(</span><span>next_token</span><span>)],</span> <span>dim</span><span>=-</span><span>1</span><span>)</span> <span>generated_tokens</span><span>.</span><span>append</span><span>(</span><span>current_token</span><span>)</span> <span>if</span> <span>current_token</span> <span>==</span> <span>eos_token_id</span><span>:</span> <span>break</span> <span>current_generated_text</span> <span>=</span> <span>tokenizer</span><span>.</span><span>decode</span><span>(</span> <span>full_generated_ids</span><span>[</span><span>0</span><span>][</span><span>input_ids</span><span>.</span><span>shape</span><span>[</span><span>1</span><span>]</span> <span>:].</span><span>cpu</span><span>(),</span> <span>skip_special_tokens</span><span>=</span><span>True</span> <span>)</span> <span># Check stopping conditions </span> <span>if</span> <span>any</span><span>(</span><span>stop_str</span> <span>in</span> <span>current_generated_text</span> <span>for</span> <span>stop_str</span> <span>in</span> <span>stopping_strings</span><span>):</span> <span>break</span> <span>if </span><span>(</span> <span>len</span><span>(</span><span>generated_tokens</span><span>)</span> <span>>=</span> <span>2</span> <span>*</span> <span>repetition_window</span> <span>and</span> 
<span>generated_tokens</span><span>[</span><span>-</span><span>repetition_window</span><span>:]</span> <span>==</span> <span>generated_tokens</span><span>[</span><span>-</span><span>2</span> <span>*</span> <span>repetition_window</span> <span>:</span> <span>-</span><span>repetition_window</span><span>]</span> <span>):</span> <span>break</span> <span>diff</span> <span>=</span> <span>current_generated_text</span><span>[</span><span>len</span><span>(</span><span>previous_generated_text</span><span>)</span> <span>:]</span> <span>if</span> <span>diff</span><span>:</span> <span>yield</span> <span>diff</span> <span>previous_generated_text</span> <span>=</span> <span>current_generated_text</span> <span>if</span> <span>device</span><span>.</span><span>type</span> <span>==</span> <span>'</span><span>mps</span><span>'</span><span>:</span> <span>await</span> <span>asyncio</span><span>.</span><span>sleep</span><span>(</span><span>0.001</span><span>)</span> <span># Yield event loop for MPS </span> <span>elif</span> <span>device</span><span>.</span><span>type</span> <span>==</span> <span>'</span><span>cpu</span><span>'</span><span>:</span> <span>await</span> <span>asyncio</span><span>.</span><span>sleep</span><span>(</span><span>0.005</span><span>)</span> <span># Reduce CPU load </span> <span>else</span><span>:</span> <span>await</span> <span>asyncio</span><span>.</span><span>sleep</span><span>(</span><span>0</span><span>)</span> <span># Return the final text </span> <span>yield</span> <span>'</span><span>[END]</span><span>'</span> <span>async</span> <span>def</span> <span>gpt_question_and_answer</span><span>(</span><span>question</span><span>:</span> <span>str</span><span>):</span> <span>"""</span><span>Stream an answer with repetition penalty and better stopping checks.</span><span>"""</span> <span>prompt</span><span>,</span> <span>stopping_strings</span> <span>=</span> <span>await</span> <span>get_stopping_strings</span><span>(</span><span>'</span><span>exNormal</span><span>'</span><span>,</span> <span>question</span><span>,</span> <span>context</span><span>=</span><span>base_settings</span><span>.</span><span>context</span><span>)</span> <span>async</span> <span>for</span> <span>chunk</span> <span>in</span> <span>stream_chat_response</span><span>(</span> <span>prompt</span><span>,</span> <span>TOKENIZER</span><span>,</span> <span>MODEL</span><span>,</span> <span>stopping_strings</span><span>=</span><span>stopping_strings</span><span>,</span> <span>max_new_tokens</span><span>=</span><span>2000</span><span>,</span> <span># Increase this for longer outputs </span> <span>top_p</span><span>=</span><span>0.95</span><span>,</span> <span>):</span> <span>yield</span> <span>chunk</span>import asyncio import torch from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast from src.utils.base import get_device, get_stopping_strings from src.utils.settings import base_settings TOKENIZER, MODEL = None, None async def prepare_auto_tokenizer_and_model(model_name: str) -> None: """ Prepare and load the tokenizer and model with hardware-specific optimizations. Args: model_name (str): The name of the model to load. Returns: tuple: The loaded tokenizer and model. 
""" global TOKENIZER, MODEL if TOKENIZER is not None and MODEL is not None: return device, _ = get_device() # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', truncation_side='left', use_fast=True) tokenizer.pad_token = tokenizer.eos_token TOKENIZER = tokenizer # Model-specific settings load_kwargs = {} if 'phi-3' in model_name.lower(): load_kwargs['trust_remote_code'] = True load_kwargs['attn_implementation'] = 'eager' # No CUDA, so use eager mode elif 'starcoder' in model_name.lower(): load_kwargs['trust_remote_code'] = True # Load model based on device if device.type == 'cpu': # Use float32 for CPU and apply dynamic quantization model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32, **load_kwargs, ) model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) elif device.type == 'mps': # Use bfloat16 for MPS model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, **load_kwargs, ).to(device) model.eval() MODEL = model async def cleanup_auto_model() -> None: """Clear model from memory, important for Apple Silicon""" global MODEL, TOKENIZER if torch.backends.mps.is_available(): torch.mps.empty_cache() del MODEL del TOKENIZER def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")): """ Filter a distribution of logits using top-k and nucleus (top-p) filtering. """ if top_k > 0: values, _ = torch.topk(logits, top_k) kth_value = values[..., -1, None] logits = torch.where(logits < kth_value, torch.full_like(logits, filter_value), logits) if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices[sorted_indices_to_remove] logits[indices_to_remove] = filter_value return logits async def stream_chat_response( prompt: str, tokenizer: PreTrainedTokenizerFast, model, stopping_strings: list[str], max_new_tokens: int = 100, temperature: float = 0.5, top_k: int = 50, top_p: float = 0.9, repetition_penalty: float = 1.5, repetition_window: int = 10, ): device = next(model.parameters()).device eos_token_id = tokenizer.eos_token_id inputs = tokenizer(prompt, return_tensors='pt') input_ids = inputs['input_ids'].to(device) attention_mask = inputs['attention_mask'].to(device) full_generated_ids = input_ids previous_generated_text = '' generated_tokens = [] for _ in range(max_new_tokens): with torch.no_grad(): if device.type == 'cuda': with torch.autocast(device_type='cuda', dtype=torch.float16): outputs = model(full_generated_ids, attention_mask=attention_mask) elif device.type == 'mps': with torch.autocast(device_type='mps', dtype=torch.bfloat16): outputs = model(full_generated_ids, attention_mask=attention_mask) else: # CPU outputs = model(full_generated_ids, attention_mask=attention_mask) next_token_logits = outputs.logits[:, -1, :] # Apply repetition penalty if repetition_penalty != 1.0: for token in set(generated_tokens[-20:]): next_token_logits[:, token] /= repetition_penalty scaled_logits = next_token_logits / temperature filtered_logits = top_k_top_p_filtering(scaled_logits[0], top_k=top_k, top_p=top_p) probabilities = torch.softmax(filtered_logits, dim=-1) next_token = torch.multinomial(probabilities, num_samples=1).unsqueeze(0) current_token = 
next_token.item() full_generated_ids = torch.cat([full_generated_ids, next_token], dim=-1) attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1) generated_tokens.append(current_token) if current_token == eos_token_id: break current_generated_text = tokenizer.decode( full_generated_ids[0][input_ids.shape[1] :].cpu(), skip_special_tokens=True ) # Check stopping conditions if any(stop_str in current_generated_text for stop_str in stopping_strings): break if ( len(generated_tokens) >= 2 * repetition_window and generated_tokens[-repetition_window:] == generated_tokens[-2 * repetition_window : -repetition_window] ): break diff = current_generated_text[len(previous_generated_text) :] if diff: yield diff previous_generated_text = current_generated_text if device.type == 'mps': await asyncio.sleep(0.001) # Yield event loop for MPS elif device.type == 'cpu': await asyncio.sleep(0.005) # Reduce CPU load else: await asyncio.sleep(0) # Return the final text yield '[END]' async def gpt_question_and_answer(question: str): """Stream an answer with repetition penalty and better stopping checks.""" prompt, stopping_strings = await get_stopping_strings('exNormal', question, context=base_settings.context) async for chunk in stream_chat_response( prompt, TOKENIZER, MODEL, stopping_strings=stopping_strings, max_new_tokens=2000, # Increase this for longer outputs top_p=0.95, ): yield chunk
Enter fullscreen mode Exit fullscreen mode
`prepare_auto_tokenizer_and_model` initializes the tokenizer and model. To avoid redundant loading, we use global variables and adapt the loading process to the available device (CPU, MPS, or CUDA). `torch.quantization.quantize_dynamic` is a nifty way to quantize the model so that it runs faster and requires less memory. If your machine supports CUDA, you will need to add a corresponding branch to the device-loading logic yourself.
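As a rough guide, here is a minimal sketch of what that branch might look like, assuming half precision (`float16`) is acceptable on your GPU. The helper name is mine, not part of the repository:

```python
import torch
from transformers import AutoModelForCausalLM


def load_model_for_cuda(model_name: str, load_kwargs: dict):
    """Hypothetical CUDA branch: load in float16 and move the model to the GPU."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # half precision roughly halves GPU memory use
        **load_kwargs,
    ).to(torch.device('cuda'))
    model.eval()
    return model
```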
`cleanup_auto_model` releases the model and tokenizer from memory. This is particularly important on memory-constrained devices like those with Apple Silicon (MPS), where leaked model weights can quickly degrade performance.
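If you are wondering where these load/cleanup helpers get called, a typical wiring uses aiohttp's lifecycle signals. This is a hypothetical sketch (the model id is only an example), assuming both functions are importable from the module above:

```python
from aiohttp import web

async def on_startup(app: web.Application) -> None:
    # Load the weights once when the server starts
    await prepare_auto_tokenizer_and_model('microsoft/Phi-3-mini-4k-instruct')  # example model id

async def on_cleanup(app: web.Application) -> None:
    # Release the weights when the server shuts down
    await cleanup_auto_model()

app = web.Application()
app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
```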
`top_k_top_p_filtering` implements two common techniques for controlling the diversity and quality of generated text. Top-k filtering limits the model's choices to the k most likely tokens, while top-p (nucleus) sampling considers the smallest set of tokens whose cumulative probability exceeds p. In simple terms, they strike a balance between creative and coherent text by preventing the model from sampling nonsensical or irrelevant tokens.
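To see the effect concretely, here is a toy run of the function above on made-up logits for a five-token vocabulary:

```python
import torch

# Illustrative logits for a 5-token vocabulary (made-up numbers)
logits = torch.tensor([2.0, 1.5, 0.3, -1.0, -3.0])

# Keep only the 3 most likely tokens, then drop the tail once the
# cumulative probability exceeds 0.9
filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=0.9)
print(filtered)  # the two least likely positions become -inf
print(torch.softmax(filtered, dim=-1))  # probability mass is redistributed over the survivors
```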
The `stream_chat_response` function is the heart of our real-time chatbot's text generation. It takes a `prompt`, the `tokenizer` and `model`, and sampling configurations, and yields the generated text chunk by chunk. For performance, within the generation loop, we:

- Optimize Device Usage: read the active device from the model's parameters and keep all input tensors on it for faster access.
- Disable Gradients: use `torch.no_grad()` to skip unnecessary gradient bookkeeping during inference.
- Enable Auto-Precision: utilize `torch.autocast` for faster, mixed-precision operations on CUDA and MPS.
Generation stops automatically when the model emits its end-of-sequence token, when one of the stopping strings appears in the decoded text, or when the last few tokens repeat verbatim. To keep streaming smooth, the function also briefly yields control back to the event loop between tokens via `asyncio.sleep`. Finally, after generating the complete response, `stream_chat_response` yields `[END]` to signal to the frontend that the stream is complete.
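For a quick local smoke test, you could drive the generator directly. This is a hypothetical script, not part of the repository, and the model id is just an example:

```python
import asyncio

async def main() -> None:
    # Assumed setup; any causal LM id works here
    await prepare_auto_tokenizer_and_model('microsoft/Phi-3-mini-4k-instruct')
    async for chunk in gpt_question_and_answer('Explain transformers in one paragraph.'):
        if chunk == '[END]':
            break
        print(chunk, end='', flush=True)

asyncio.run(main())
```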
Disclaimer: Learning in Progress
As a learner in this field, I welcome any corrections or suggestions regarding my approach. Please feel free to point out any errors, and let’s learn together!
`gpt_question_and_answer` serves as a high-level interface for generating answers to questions. It retrieves the appropriate prompt and stopping strings via `get_stopping_strings` and then calls `stream_chat_response` to generate the text. You can tune the parameters here to your taste.
Finally, let's examine the `chat_engine.py` file. This module leverages pre-trained BERT-based models to extract answers from a given context.
```python
import asyncio

from transformers import pipeline

from src.utils.base import get_device
from src.utils.settings import base_settings

QA_PIPELINE = None


async def prepare_qa_tokenizer_and_model(model_name: str) -> None:
    """
    Prepare and load pretrained QA pipeline with hardware-specific optimizations.
    """
    global QA_PIPELINE
    if QA_PIPELINE is None:
        QA_PIPELINE = pipeline(
            'question-answering',
            model=model_name,
            tokenizer=model_name,
            device=-1 if get_device()[0].type == 'cpu' else 0,
        )
        base_settings.logger.info('QA pipeline initialized')


async def squad_question_answering(question: str) -> str | dict[str, str]:
    """Optimized QA"""
    if not base_settings.context:
        return {'error': 'Please upload a document first.'}
    if QA_PIPELINE is None:
        return {'error': 'QA system not initialized'}

    base_settings.logger.info(f'Question: {question}')
    base_settings.logger.info(f'Context: {base_settings.context}')

    try:
        # Run the blocking pipeline call in a worker thread
        result = await asyncio.to_thread(
            QA_PIPELINE,
            question=question,
            context=base_settings.context,
        )
        return result['answer']
    except Exception as e:
        base_settings.logger.error(f'QA Error: {str(e)}')
        return {'error': 'An error occurred during QA'}
```
`prepare_qa_tokenizer_and_model` initializes the question-answering pipeline. As previously stated, the primary motivation behind the global `QA_PIPELINE` is to avoid repeatedly loading the model, which is a resource-intensive operation. In `squad_question_answering`, a key design decision is the enforcement of a context: BERT-based extractive models, while powerful, generally perform best when given a specific context to ground their answers. To avoid blocking the event loop, the actual question-answering call is offloaded to a worker thread using `asyncio.to_thread`.
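A minimal (hypothetical) smoke test might look like the following; `distilbert-base-cased-distilled-squad` is a commonly used SQuAD-tuned checkpoint, not a requirement of this project:

```python
import asyncio

async def main() -> None:
    await prepare_qa_tokenizer_and_model('distilbert-base-cased-distilled-squad')
    # Normally the context comes from an uploaded document; set it by hand here
    base_settings.context = 'The backend is built with aiohttp and streams tokens over WebSockets.'
    answer = await squad_question_answering('What framework is the backend built with?')
    print(answer)  # expected: something like 'aiohttp'

asyncio.run(main())
```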
Step 3: Building a React Frontend with Next.js and Tailwind CSS
While our server backend is functional and can be interacted with via command-line tools like `wscat` (using a command such as `wscat -c ws://localhost:PORT/chat` followed by a payload like `{"type": "auto", "question": "Write Python code for Fibonacci sequence"}`), this approach is not user-friendly. To make our AI chatbot accessible to a wider audience, including those unfamiliar with the terminal, we need a graphical user interface. We'll leverage Next.js, a React framework known for its performance and developer experience, along with Tailwind CSS for styling.
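If you prefer scripting the check instead of using wscat, the same handshake can be done from Python with aiohttp's WebSocket client. A sketch, assuming the server listens on port 8000 (adjust to yours):

```python
import asyncio

import aiohttp

async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # Port 8000 is an assumption; use your server's actual port
        async with session.ws_connect('ws://localhost:8000/chat') as ws:
            await ws.send_json({'type': 'auto', 'question': 'Write Python code for Fibonacci sequence'})
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    break
                print(msg.data, end='', flush=True)
                if '[END]' in msg.data:
                    break

asyncio.run(main())
```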
Our frontend application consists of a single page with a few key components. Let's start by examining `react-frontend/src/app/layout.tsx`:
```tsx
import type { Metadata } from "next";
import { Geist, Geist_Mono } from "next/font/google";
import "katex/dist/katex.min.css";
import "./globals.css";
import ThemeSwitcher from "$/app/ui/layout/ThemeSwitcher";

const geistSans = Geist({
  variable: "--font-geist-sans",
  subsets: ["latin"],
});

const geistMono = Geist_Mono({
  variable: "--font-geist-mono",
  subsets: ["latin"],
});

export const metadata: Metadata = {
  title: "AI Chatbot with Next.js and Python | John Owolabi Idogun",
  description:
    "Build an AI chatbot with Next.js and Python using WebSockets by John Owolabi Idogun.",
};

export default function RootLayout({
  children,
}: Readonly<{
  children: React.ReactNode;
}>) {
  return (
    <html lang="en">
      <body
        className={`${geistSans.variable} ${geistMono.variable} antialiased font-[family-name:var(--font-geist-sans)] bg-[#ffffff] text-[#171717] dark:bg-[#0a0a0a] dark:text-[#ededed] min-h-screen`}
      >
        <ThemeSwitcher />
        <main className="flex flex-col min-h-screen mx-auto max-w-10/12">
          <div className="flex-1 w-full">{children}</div>
        </main>
      </body>
    </html>
  );
}
```
This `layout.tsx` file defines the root layout of our Next.js application. It's largely based on the default code generated by `create-next-app`, enhanced with Tailwind CSS integration. We've also integrated KaTeX for rendering mathematical expressions (remember to install the `katex` package and its TypeScript definitions, as well as marked.js for Markdown processing). The `ThemeSwitcher` component, responsible for toggling between light and dark themes, is also included.
Here’s the code for the ThemeSwitcher component:
<span>"</span><span>use client</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>useState</span><span>,</span> <span>useEffect</span> <span>}</span> <span>from</span> <span>"</span><span>react</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>MoonIcon</span><span>,</span> <span>SunIcon</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/ui/icons/base</span><span>"</span><span>;</span><span>export</span> <span>default</span> <span>function</span> <span>ThemeSwitcher</span><span>()</span> <span>{</span><span>const</span> <span>[</span><span>theme</span><span>,</span> <span>setTheme</span><span>]</span> <span>=</span> <span>useState</span><span>(</span><span>"</span><span>light</span><span>"</span><span>);</span><span>useEffect</span><span>(()</span> <span>=></span> <span>{</span><span>const</span> <span>savedTheme</span> <span>=</span> <span>localStorage</span><span>.</span><span>getItem</span><span>(</span><span>"</span><span>theme</span><span>"</span><span>)</span> <span>||</span> <span>"</span><span>light</span><span>"</span><span>;</span><span>setTheme</span><span>(</span><span>savedTheme</span><span>);</span><span>document</span><span>.</span><span>documentElement</span><span>.</span><span>classList</span><span>.</span><span>toggle</span><span>(</span><span>"</span><span>dark</span><span>"</span><span>,</span> <span>savedTheme</span> <span>===</span> <span>"</span><span>dark</span><span>"</span><span>);</span><span>},</span> <span>[]);</span><span>const</span> <span>toggleTheme</span> <span>=</span> <span>()</span> <span>=></span> <span>{</span><span>const</span> <span>newTheme</span> <span>=</span> <span>theme</span> <span>===</span> <span>"</span><span>light</span><span>"</span> <span>?</span> <span>"</span><span>dark</span><span>"</span> <span>:</span> <span>"</span><span>light</span><span>"</span><span>;</span><span>setTheme</span><span>(</span><span>newTheme</span><span>);</span><span>localStorage</span><span>.</span><span>setItem</span><span>(</span><span>"</span><span>theme</span><span>"</span><span>,</span> <span>newTheme</span><span>);</span><span>document</span><span>.</span><span>documentElement</span><span>.</span><span>classList</span><span>.</span><span>toggle</span><span>(</span><span>"</span><span>dark</span><span>"</span><span>);</span><span>};</span><span>return </span><span>(</span><span><</span><span>button</span><span>onClick</span><span>=</span><span>{</span><span>toggleTheme</span><span>}</span><span>className</span><span>=</span><span>"fixed top-4 right-4 p-2 rounded-full bg-gray-200 dark:bg-gray-800 hover:bg-gray-300 dark:hover:bg-gray-700 transition-colors"</span><span>aria-label</span><span>=</span><span>"Toggle theme"</span><span>></span><span>{</span><span>theme</span> <span>===</span> <span>"</span><span>light</span><span>"</span> <span>?</span> <span>(</span><span><</span><span>MoonIcon</span> <span>className</span><span>=</span><span>"h-5 w-5"</span> <span>/></span><span>)</span> <span>:</span> <span>(</span><span><</span><span>SunIcon</span> <span>className</span><span>=</span><span>"h-5 w-5"</span> <span>/></span><span>)</span><span>}</span><span></</span><span>button</span><span>></span><span>);</span><span>}</span><span>"</span><span>use client</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>useState</span><span>,</span> <span>useEffect</span> <span>}</span> <span>from</span> <span>"</span><span>react</span><span>"</span><span>;</span> <span>import</span> 
<span>{</span> <span>MoonIcon</span><span>,</span> <span>SunIcon</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/ui/icons/base</span><span>"</span><span>;</span> <span>export</span> <span>default</span> <span>function</span> <span>ThemeSwitcher</span><span>()</span> <span>{</span> <span>const</span> <span>[</span><span>theme</span><span>,</span> <span>setTheme</span><span>]</span> <span>=</span> <span>useState</span><span>(</span><span>"</span><span>light</span><span>"</span><span>);</span> <span>useEffect</span><span>(()</span> <span>=></span> <span>{</span> <span>const</span> <span>savedTheme</span> <span>=</span> <span>localStorage</span><span>.</span><span>getItem</span><span>(</span><span>"</span><span>theme</span><span>"</span><span>)</span> <span>||</span> <span>"</span><span>light</span><span>"</span><span>;</span> <span>setTheme</span><span>(</span><span>savedTheme</span><span>);</span> <span>document</span><span>.</span><span>documentElement</span><span>.</span><span>classList</span><span>.</span><span>toggle</span><span>(</span><span>"</span><span>dark</span><span>"</span><span>,</span> <span>savedTheme</span> <span>===</span> <span>"</span><span>dark</span><span>"</span><span>);</span> <span>},</span> <span>[]);</span> <span>const</span> <span>toggleTheme</span> <span>=</span> <span>()</span> <span>=></span> <span>{</span> <span>const</span> <span>newTheme</span> <span>=</span> <span>theme</span> <span>===</span> <span>"</span><span>light</span><span>"</span> <span>?</span> <span>"</span><span>dark</span><span>"</span> <span>:</span> <span>"</span><span>light</span><span>"</span><span>;</span> <span>setTheme</span><span>(</span><span>newTheme</span><span>);</span> <span>localStorage</span><span>.</span><span>setItem</span><span>(</span><span>"</span><span>theme</span><span>"</span><span>,</span> <span>newTheme</span><span>);</span> <span>document</span><span>.</span><span>documentElement</span><span>.</span><span>classList</span><span>.</span><span>toggle</span><span>(</span><span>"</span><span>dark</span><span>"</span><span>);</span> <span>};</span> <span>return </span><span>(</span> <span><</span><span>button</span> <span>onClick</span><span>=</span><span>{</span><span>toggleTheme</span><span>}</span> <span>className</span><span>=</span><span>"fixed top-4 right-4 p-2 rounded-full bg-gray-200 dark:bg-gray-800 hover:bg-gray-300 dark:hover:bg-gray-700 transition-colors"</span> <span>aria-label</span><span>=</span><span>"Toggle theme"</span> <span>></span> <span>{</span><span>theme</span> <span>===</span> <span>"</span><span>light</span><span>"</span> <span>?</span> <span>(</span> <span><</span><span>MoonIcon</span> <span>className</span><span>=</span><span>"h-5 w-5"</span> <span>/></span> <span>)</span> <span>:</span> <span>(</span> <span><</span><span>SunIcon</span> <span>className</span><span>=</span><span>"h-5 w-5"</span> <span>/></span> <span>)</span><span>}</span> <span></</span><span>button</span><span>></span> <span>);</span> <span>}</span>"use client"; import { useState, useEffect } from "react"; import { MoonIcon, SunIcon } from "$/app/ui/icons/base"; export default function ThemeSwitcher() { const [theme, setTheme] = useState("light"); useEffect(() => { const savedTheme = localStorage.getItem("theme") || "light"; setTheme(savedTheme); document.documentElement.classList.toggle("dark", savedTheme === "dark"); }, []); const toggleTheme = () => { const newTheme = theme === "light" ? 
"dark" : "light"; setTheme(newTheme); localStorage.setItem("theme", newTheme); document.documentElement.classList.toggle("dark"); }; return ( <button onClick={toggleTheme} className="fixed top-4 right-4 p-2 rounded-full bg-gray-200 dark:bg-gray-800 hover:bg-gray-300 dark:hover:bg-gray-700 transition-colors" aria-label="Toggle theme" > {theme === "light" ? ( <MoonIcon className="h-5 w-5" /> ) : ( <SunIcon className="h-5 w-5" /> )} </button> ); }
Enter fullscreen mode Exit fullscreen mode
For Svelte developers, React's `useEffect` hook serves a purpose similar to Svelte 5's `$effect` block. Both let you run code in response to changes (though `useEffect` requires an explicit dependency array, which Svelte infers automatically), and, critically, both provide a safe context for interacting with the DOM.
Next up is `react-frontend/src/app/page.tsx`:
<span>"</span><span>use client</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>useCallback</span><span>,</span> <span>useState</span> <span>}</span> <span>from</span> <span>"</span><span>react</span><span>"</span><span>;</span><span>import</span> <span>ChatContainer</span> <span>from</span> <span>"</span><span>$/app/ui/chat/ChatContainer</span><span>"</span><span>;</span><span>import</span> <span>ChatInput</span> <span>from</span> <span>"</span><span>$/app/ui/chat/ChatInput</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>Logo</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/ui/icons/base</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>useWebSocket</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/lib/hooks/useWebSocket</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>Message</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/lib/types</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>useAutoScroll</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/lib/hooks/useAutoScroll</span><span>"</span><span>;</span><span>import</span> <span>{</span> <span>BASE_WS_URL</span> <span>}</span> <span>from</span> <span>"</span><span>./lib/constants</span><span>"</span><span>;</span><span>export</span> <span>default</span> <span>function</span> <span>Home</span><span>()</span> <span>{</span><span>const</span> <span>[</span><span>messages</span><span>,</span> <span>setMessages</span><span>]</span> <span>=</span> <span>useState</span><span><</span><span>Message</span><span>[]</span><span>></span><span>([]);</span><span>const</span> <span>scrollRef</span> <span>=</span> <span>useAutoScroll</span><span><</span><span>HTMLDivElement</span><span>></span><span>(</span><span>messages</span><span>);</span><span>const</span> <span>handleSend</span> <span>=</span> <span>(</span><span>message</span><span>:</span> <span>string</span><span>,</span> <span>type</span><span>:</span> <span>string</span><span>)</span> <span>=></span> <span>{</span><span>const</span> <span>newMessage</span> <span>=</span> <span>{</span><span>id</span><span>:</span> <span>Date</span><span>.</span><span>now</span><span>(),</span><span>text</span><span>:</span> <span>message</span><span>,</span><span>sender</span><span>:</span> <span>"</span><span>user</span><span>"</span> <span>as</span> <span>const</span><span>,</span><span>};</span><span>// Add a loading message for the bot</span><span>const</span> <span>loadingMessage</span> <span>=</span> <span>{</span><span>id</span><span>:</span> <span>Date</span><span>.</span><span>now</span><span>()</span> <span>+</span> <span>1</span><span>,</span><span>text</span><span>:</span> <span>""</span><span>,</span><span>sender</span><span>:</span> <span>"</span><span>bot</span><span>"</span> <span>as</span> <span>const</span><span>,</span><span>loading</span><span>:</span> <span>true</span><span>,</span><span>complete</span><span>:</span> <span>false</span><span>,</span><span>};</span><span>setMessages</span><span>((</span><span>prev</span><span>)</span> <span>=></span> <span>[...</span><span>prev</span><span>,</span> <span>newMessage</span><span>,</span> <span>loadingMessage</span><span>]);</span><span>sendMessage</span><span>(</span><span>message</span><span>,</span> <span>type</span><span>);</span><span>};</span><span>const</span> <span>handleBotMessage</span> <span>=</span> 
<span>useCallback</span><span>(</span><span>(</span><span>content</span><span>:</span> <span>string</span><span>,</span> <span>isComplete</span><span>?:</span> <span>boolean</span><span>)</span> <span>=></span> <span>{</span><span>setMessages</span><span>((</span><span>prev</span><span>)</span> <span>=></span> <span>{</span><span>const</span> <span>lastMessage</span> <span>=</span> <span>prev</span><span>[</span><span>prev</span><span>.</span><span>length</span> <span>-</span> <span>1</span><span>];</span><span>if </span><span>(</span><span>isComplete</span> <span>&&</span> <span>lastMessage</span><span>?.</span><span>sender</span> <span>===</span> <span>"</span><span>bot</span><span>"</span><span>)</span> <span>{</span><span>const</span> <span>updatedMessages</span> <span>=</span> <span>[...</span><span>prev</span><span>];</span><span>updatedMessages</span><span>[</span><span>prev</span><span>.</span><span>length</span> <span>-</span> <span>1</span><span>]</span> <span>=</span> <span>{</span><span>...</span><span>lastMessage</span><span>,</span><span>complete</span><span>:</span> <span>true</span><span>,</span><span>loading</span><span>:</span> <span>false</span><span>,</span><span>};</span><span>return</span> <span>updatedMessages</span><span>;</span><span>}</span><span>if </span><span>(</span><span>!</span><span>content</span><span>)</span> <span>return</span> <span>prev</span><span>;</span><span>if </span><span>(</span><span>!</span><span>lastMessage</span> <span>||</span><span>lastMessage</span><span>.</span><span>sender</span> <span>!==</span> <span>"</span><span>bot</span><span>"</span> <span>||</span><span>lastMessage</span><span>.</span><span>complete</span><span>)</span> <span>{</span><span>return</span> <span>[</span><span>...</span><span>prev</span><span>,</span><span>{</span><span>id</span><span>:</span> <span>Date</span><span>.</span><span>now</span><span>(),</span><span>text</span><span>:</span> <span>content</span><span>,</span><span>sender</span><span>:</span> <span>"</span><span>bot</span><span>"</span><span>,</span><span>complete</span><span>:</span> <span>false</span><span>,</span><span>loading</span><span>:</span> <span>false</span><span>,</span><span>},</span><span>];</span><span>}</span><span>const</span> <span>updatedMessages</span> <span>=</span> <span>[...</span><span>prev</span><span>];</span><span>updatedMessages</span><span>[</span><span>prev</span><span>.</span><span>length</span> <span>-</span> <span>1</span><span>]</span> <span>=</span> <span>{</span><span>...</span><span>lastMessage</span><span>,</span><span>text</span><span>:</span> <span>lastMessage</span><span>.</span><span>text</span> <span>+</span> <span>content</span><span>,</span><span>loading</span><span>:</span> <span>false</span><span>,</span><span>};</span><span>return</span> <span>updatedMessages</span><span>;</span><span>});</span><span>},</span><span>[]</span><span>);</span><span>const</span> <span>{</span> <span>sendMessage</span><span>,</span> <span>isConnected</span><span>,</span> <span>error</span> <span>}</span> <span>=</span> <span>useWebSocket</span><span>(</span><span>`</span><span>${</span><span>BASE_WS_URL</span><span>}</span><span>/chat`</span><span>,</span><span>handleBotMessage</span><span>);</span><span>return </span><span>(</span><span><</span><span>div</span> <span>className</span><span>=</span><span>"flex flex-col flex-1 h-[calc(100vh-10rem)]"</span><span>></span><span>{</span><span>!</span><span>messages</span><span>.</span><span>length</span> <span>?</span> 
<span>(</span><span><</span><span>div</span> <span>className</span><span>=</span><span>"flex flex-col items-center justify-center flex-1"</span><span>></span><span><</span><span>div</span> <span>className</span><span>=</span><span>"flex flex-col items-center max-w-3xl w-full mx-auto"</span><span>></span><span><</span><span>Logo</span> <span>className</span><span>=</span><span>"h-16 w-16 mb-6 text-[#171717] dark:text-[#ededed]"</span> <span>/></span><span><</span><span>p</span> <span>className</span><span>=</span><span>"text-gray-600 dark:text-gray-400 mb-8"</span><span>></span>Ask me anything! I'm here to help.<span></</span><span>p</span><span>></span><span><</span><span>div</span> <span>className</span><span>=</span><span>"w-full"</span><span>></span><span><</span><span>ChatInput</span> <span>onSend</span><span>=</span><span>{</span><span>handleSend</span><span>}</span> <span>messageCount</span><span>=</span><span>{</span><span>messages</span><span>.</span><span>length</span><span>}</span> <span>/></span><span></</span><span>div</span><span>></span><span></</span><span>div</span><span>></span><span></</span><span>div</span><span>></span><span>)</span> <span>:</span> <span>(</span><span><></span><span><</span><span>div</span> <span>ref</span><span>=</span><span>{</span><span>scrollRef</span><span>}</span> <span>className</span><span>=</span><span>"flex-1 overflow-y-auto no-scrollbar"</span><span>></span><span><</span><span>ChatContainer</span> <span>messages</span><span>=</span><span>{</span><span>messages</span><span>}</span> <span>/></span><span></</span><span>div</span><span>></span><span><</span><span>ChatInput</span> <span>onSend</span><span>=</span><span>{</span><span>handleSend</span><span>}</span> <span>messageCount</span><span>=</span><span>{</span><span>messages</span><span>.</span><span>length</span><span>}</span> <span>/></span><span></></span><span>)</span><span>}</span><span></</span><span>div</span><span>></span><span>);</span><span>}</span><span>"</span><span>use client</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>useCallback</span><span>,</span> <span>useState</span> <span>}</span> <span>from</span> <span>"</span><span>react</span><span>"</span><span>;</span> <span>import</span> <span>ChatContainer</span> <span>from</span> <span>"</span><span>$/app/ui/chat/ChatContainer</span><span>"</span><span>;</span> <span>import</span> <span>ChatInput</span> <span>from</span> <span>"</span><span>$/app/ui/chat/ChatInput</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>Logo</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/ui/icons/base</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>useWebSocket</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/lib/hooks/useWebSocket</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>Message</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/lib/types</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>useAutoScroll</span> <span>}</span> <span>from</span> <span>"</span><span>$/app/lib/hooks/useAutoScroll</span><span>"</span><span>;</span> <span>import</span> <span>{</span> <span>BASE_WS_URL</span> <span>}</span> <span>from</span> <span>"</span><span>./lib/constants</span><span>"</span><span>;</span> <span>export</span> <span>default</span> <span>function</span> <span>Home</span><span>()</span> <span>{</span> <span>const</span> 
<span>[</span><span>messages</span><span>,</span> <span>setMessages</span><span>]</span> <span>=</span> <span>useState</span><span><</span><span>Message</span><span>[]</span><span>></span><span>([]);</span> <span>const</span> <span>scrollRef</span> <span>=</span> <span>useAutoScroll</span><span><</span><span>HTMLDivElement</span><span>></span><span>(</span><span>messages</span><span>);</span> <span>const</span> <span>handleSend</span> <span>=</span> <span>(</span><span>message</span><span>:</span> <span>string</span><span>,</span> <span>type</span><span>:</span> <span>string</span><span>)</span> <span>=></span> <span>{</span> <span>const</span> <span>newMessage</span> <span>=</span> <span>{</span> <span>id</span><span>:</span> <span>Date</span><span>.</span><span>now</span><span>(),</span> <span>text</span><span>:</span> <span>message</span><span>,</span> <span>sender</span><span>:</span> <span>"</span><span>user</span><span>"</span> <span>as</span> <span>const</span><span>,</span> <span>};</span> <span>// Add a loading message for the bot</span> <span>const</span> <span>loadingMessage</span> <span>=</span> <span>{</span> <span>id</span><span>:</span> <span>Date</span><span>.</span><span>now</span><span>()</span> <span>+</span> <span>1</span><span>,</span> <span>text</span><span>:</span> <span>""</span><span>,</span> <span>sender</span><span>:</span> <span>"</span><span>bot</span><span>"</span> <span>as</span> <span>const</span><span>,</span> <span>loading</span><span>:</span> <span>true</span><span>,</span> <span>complete</span><span>:</span> <span>false</span><span>,</span> <span>};</span> <span>setMessages</span><span>((</span><span>prev</span><span>)</span> <span>=></span> <span>[...</span><span>prev</span><span>,</span> <span>newMessage</span><span>,</span> <span>loadingMessage</span><span>]);</span> <span>sendMessage</span><span>(</span><span>message</span><span>,</span> <span>type</span><span>);</span> <span>};</span> <span>const</span> <span>handleBotMessage</span> <span>=</span> <span>useCallback</span><span>(</span> <span>(</span><span>content</span><span>:</span> <span>string</span><span>,</span> <span>isComplete</span><span>?:</span> <span>boolean</span><span>)</span> <span>=></span> <span>{</span> <span>setMessages</span><span>((</span><span>prev</span><span>)</span> <span>=></span> <span>{</span> <span>const</span> <span>lastMessage</span> <span>=</span> <span>prev</span><span>[</span><span>prev</span><span>.</span><span>length</span> <span>-</span> <span>1</span><span>];</span> <span>if </span><span>(</span><span>isComplete</span> <span>&&</span> <span>lastMessage</span><span>?.</span><span>sender</span> <span>===</span> <span>"</span><span>bot</span><span>"</span><span>)</span> <span>{</span> <span>const</span> <span>updatedMessages</span> <span>=</span> <span>[...</span><span>prev</span><span>];</span> <span>updatedMessages</span><span>[</span><span>prev</span><span>.</span><span>length</span> <span>-</span> <span>1</span><span>]</span> <span>=</span> <span>{</span> <span>...</span><span>lastMessage</span><span>,</span> <span>complete</span><span>:</span> <span>true</span><span>,</span> <span>loading</span><span>:</span> <span>false</span><span>,</span> <span>};</span> <span>return</span> <span>updatedMessages</span><span>;</span> <span>}</span> <span>if </span><span>(</span><span>!</span><span>content</span><span>)</span> <span>return</span> <span>prev</span><span>;</span> <span>if </span><span>(</span> <span>!</span><span>lastMessage</span> <span>||</span> 
<span>lastMessage</span><span>.</span><span>sender</span> <span>!==</span> <span>"</span><span>bot</span><span>"</span> <span>||</span> <span>lastMessage</span><span>.</span><span>complete</span> <span>)</span> <span>{</span> <span>return</span> <span>[</span> <span>...</span><span>prev</span><span>,</span> <span>{</span> <span>id</span><span>:</span> <span>Date</span><span>.</span><span>now</span><span>(),</span> <span>text</span><span>:</span> <span>content</span><span>,</span> <span>sender</span><span>:</span> <span>"</span><span>bot</span><span>"</span><span>,</span> <span>complete</span><span>:</span> <span>false</span><span>,</span> <span>loading</span><span>:</span> <span>false</span><span>,</span> <span>},</span> <span>];</span> <span>}</span> <span>const</span> <span>updatedMessages</span> <span>=</span> <span>[...</span><span>prev</span><span>];</span> <span>updatedMessages</span><span>[</span><span>prev</span><span>.</span><span>length</span> <span>-</span> <span>1</span><span>]</span> <span>=</span> <span>{</span> <span>...</span><span>lastMessage</span><span>,</span> <span>text</span><span>:</span> <span>lastMessage</span><span>.</span><span>text</span> <span>+</span> <span>content</span><span>,</span> <span>loading</span><span>:</span> <span>false</span><span>,</span> <span>};</span> <span>return</span> <span>updatedMessages</span><span>;</span> <span>});</span> <span>},</span> <span>[]</span> <span>);</span> <span>const</span> <span>{</span> <span>sendMessage</span><span>,</span> <span>isConnected</span><span>,</span> <span>error</span> <span>}</span> <span>=</span> <span>useWebSocket</span><span>(</span> <span>`</span><span>${</span><span>BASE_WS_URL</span><span>}</span><span>/chat`</span><span>,</span> <span>handleBotMessage</span> <span>);</span> <span>return </span><span>(</span> <span><</span><span>div</span> <span>className</span><span>=</span><span>"flex flex-col flex-1 h-[calc(100vh-10rem)]"</span><span>></span> <span>{</span><span>!</span><span>messages</span><span>.</span><span>length</span> <span>?</span> <span>(</span> <span><</span><span>div</span> <span>className</span><span>=</span><span>"flex flex-col items-center justify-center flex-1"</span><span>></span> <span><</span><span>div</span> <span>className</span><span>=</span><span>"flex flex-col items-center max-w-3xl w-full mx-auto"</span><span>></span> <span><</span><span>Logo</span> <span>className</span><span>=</span><span>"h-16 w-16 mb-6 text-[#171717] dark:text-[#ededed]"</span> <span>/></span> <span><</span><span>p</span> <span>className</span><span>=</span><span>"text-gray-600 dark:text-gray-400 mb-8"</span><span>></span> Ask me anything! I'm here to help. 
<span></</span><span>p</span><span>></span> <span><</span><span>div</span> <span>className</span><span>=</span><span>"w-full"</span><span>></span> <span><</span><span>ChatInput</span> <span>onSend</span><span>=</span><span>{</span><span>handleSend</span><span>}</span> <span>messageCount</span><span>=</span><span>{</span><span>messages</span><span>.</span><span>length</span><span>}</span> <span>/></span> <span></</span><span>div</span><span>></span> <span></</span><span>div</span><span>></span> <span></</span><span>div</span><span>></span> <span>)</span> <span>:</span> <span>(</span> <span><></span> <span><</span><span>div</span> <span>ref</span><span>=</span><span>{</span><span>scrollRef</span><span>}</span> <span>className</span><span>=</span><span>"flex-1 overflow-y-auto no-scrollbar"</span><span>></span> <span><</span><span>ChatContainer</span> <span>messages</span><span>=</span><span>{</span><span>messages</span><span>}</span> <span>/></span> <span></</span><span>div</span><span>></span> <span><</span><span>ChatInput</span> <span>onSend</span><span>=</span><span>{</span><span>handleSend</span><span>}</span> <span>messageCount</span><span>=</span><span>{</span><span>messages</span><span>.</span><span>length</span><span>}</span> <span>/></span> <span></></span> <span>)</span><span>}</span> <span></</span><span>div</span><span>></span> <span>);</span> <span>}</span>"use client"; import { useCallback, useState } from "react"; import ChatContainer from "$/app/ui/chat/ChatContainer"; import ChatInput from "$/app/ui/chat/ChatInput"; import { Logo } from "$/app/ui/icons/base"; import { useWebSocket } from "$/app/lib/hooks/useWebSocket"; import { Message } from "$/app/lib/types"; import { useAutoScroll } from "$/app/lib/hooks/useAutoScroll"; import { BASE_WS_URL } from "./lib/constants"; export default function Home() { const [messages, setMessages] = useState<Message[]>([]); const scrollRef = useAutoScroll<HTMLDivElement>(messages); const handleSend = (message: string, type: string) => { const newMessage = { id: Date.now(), text: message, sender: "user" as const, }; // Add a loading message for the bot const loadingMessage = { id: Date.now() + 1, text: "", sender: "bot" as const, loading: true, complete: false, }; setMessages((prev) => [...prev, newMessage, loadingMessage]); sendMessage(message, type); }; const handleBotMessage = useCallback( (content: string, isComplete?: boolean) => { setMessages((prev) => { const lastMessage = prev[prev.length - 1]; if (isComplete && lastMessage?.sender === "bot") { const updatedMessages = [...prev]; updatedMessages[prev.length - 1] = { ...lastMessage, complete: true, loading: false, }; return updatedMessages; } if (!content) return prev; if ( !lastMessage || lastMessage.sender !== "bot" || lastMessage.complete ) { return [ ...prev, { id: Date.now(), text: content, sender: "bot", complete: false, loading: false, }, ]; } const updatedMessages = [...prev]; updatedMessages[prev.length - 1] = { ...lastMessage, text: lastMessage.text + content, loading: false, }; return updatedMessages; }); }, [] ); const { sendMessage, isConnected, error } = useWebSocket( `${BASE_WS_URL}/chat`, handleBotMessage ); return ( <div className="flex flex-col flex-1 h-[calc(100vh-10rem)]"> {!messages.length ? ( <div className="flex flex-col items-center justify-center flex-1"> <div className="flex flex-col items-center max-w-3xl w-full mx-auto"> <Logo className="h-16 w-16 mb-6 text-[#171717] dark:text-[#ededed]" /> <p className="text-gray-600 dark:text-gray-400 mb-8"> Ask me anything! 
I'm here to help. </p> <div className="w-full"> <ChatInput onSend={handleSend} messageCount={messages.length} /> </div> </div> </div> ) : ( <> <div ref={scrollRef} className="flex-1 overflow-y-auto no-scrollbar"> <ChatContainer messages={messages} /> </div> <ChatInput onSend={handleSend} messageCount={messages.length} /> </> )} </div> ); }
This file represents the main page component of our chat application, housing the core logic for message handling and UI rendering. The `Home` component manages the chat interface, including displaying messages, handling user input, and communicating with the backend. The `handleSend` function adds a user message and a loading message to the state, while `handleBotMessage` updates the UI with the bot’s responses as they stream in.
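For reference, the `Message` type imported from `$/app/lib/types` isn’t listed in this article. Based on how it’s used above, it presumably looks something like this sketch (field names are inferred from usage, not copied from the repo):

```ts
// $/app/lib/types.ts (inferred from usage in this page; not the repo's exact definition)
export type Message = {
  id: number;              // Date.now()-based identifier
  text: string;            // accumulated message content
  sender: "user" | "bot";  // who produced the message
  loading?: boolean;       // true while the bot placeholder is "thinking"
  complete?: boolean;      // true once the backend signals "[END]"
};
```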
The custom `useWebSocket` hook encapsulates all the WebSocket logic, including connection management, message handling, and error handling:
```ts
import { useState, useEffect, useCallback } from "react";

type WebSocketMessage = {
  type: string;
  question: string;
};

export const useWebSocket = (
  url: string,
  onMessage: (content: string, isComplete?: boolean) => void
) => {
  const [socket, setSocket] = useState<WebSocket | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [error, setError] = useState<string | null>(null);

  useEffect(() => {
    const ws = new WebSocket(url);

    ws.onopen = () => {
      console.log("WebSocket Connected");
      setIsConnected(true);
      setError(null);
    };

    ws.onmessage = (event) => {
      try {
        const data = JSON.parse(event.data);
        // Handle the streaming response
        if (data.answer === "[END]") {
          onMessage("", true); // Signal completion
        } else {
          onMessage(data.answer, false);
        }
      } catch (e) {
        console.error("Error parsing message:", e);
      }
    };

    ws.onclose = () => {
      console.log("WebSocket Disconnected");
      setIsConnected(false);
      // Naive reconnect: note that the replacement socket is created without
      // handlers attached; this effect only re-runs when `url` or `onMessage` change.
      setTimeout(() => setSocket(new WebSocket(url)), 3000);
    };

    ws.onerror = (event) => {
      console.error("WebSocket error:", event);
      setError("WebSocket error occurred");
    };

    setSocket(ws);

    return () => {
      ws.close();
    };
  }, [url, onMessage]);

  const sendMessage = useCallback(
    (message: string, type: string) => {
      if (socket && isConnected) {
        const payload: WebSocketMessage = {
          type: type,
          question: message,
        };
        socket.send(JSON.stringify(payload));
      } else {
        console.log("Socket not ready:", {
          isConnected,
          socketExists: !!socket,
        });
      }
    },
    [socket, isConnected]
  );

  return { sendMessage, isConnected, error };
};
```
The `handleSend` function adds a temporary “loading” message to the chat interface. This is a simple but effective UI trick that provides immediate feedback while waiting for the AI bot’s response; without it, the user might perceive a delay or a lack of responsiveness. The `handleBotMessage` function handles the bot’s streaming responses, efficiently updating the chat interface by appending new content to the last bot message. The `isComplete` flag, triggered by the `"[END]"` signal from the backend, ensures the loading indicator is removed once the bot finishes generating its response. Why wrap `handleBotMessage` in `useCallback`? Primarily to prevent unnecessary re-renders of components that depend on it: `useCallback` memoizes the function so that it only changes when its dependencies change.
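One piece not shown above is `BASE_WS_URL` from `./lib/constants`. A minimal sketch, assuming the aiohttp backend runs locally (the environment variable name and the default host/port are assumptions, not copied from the repo):

```ts
// react-frontend/src/app/lib/constants.ts (a sketch; adjust to your backend)
export const BASE_WS_URL =
  process.env.NEXT_PUBLIC_WS_URL ?? "ws://localhost:8080";

// The hook above expects the backend to stream JSON frames shaped like:
//   { "answer": "partial text..." }  // one frame per generated chunk
//   { "answer": "[END]" }            // terminal sentinel that flips `complete`
```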
To wrap up, let’s look at `react-frontend/src/app/ui/chat/ChatContainer.tsx` and `react-frontend/src/app/ui/chat/ChatInput.tsx`.
```tsx
import React from "react";
import ChatMessage from "$/app/ui/chat/ChatMessage";
import { Message } from "$/app/lib/types";

interface ChatContainerProps {
  messages: Message[];
}

const ChatContainer: React.FC<ChatContainerProps> = ({ messages }) => {
  return (
    <div className="flex flex-col h-full overflow-y-auto p-4">
      <div className="flex-1 overflow-y-scroll">
        {messages.map((message) => (
          <ChatMessage key={message.id} message={message} />
        ))}
      </div>
    </div>
  );
};

export default ChatContainer;
```
The `ChatContainer` component renders the list of `ChatMessage` components. It’s a straightforward component that iterates over the messages and displays each one.
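Because new chunks keep arriving while the user reads, the scrollable container in `page.tsx` should stay pinned to the latest message. The `useAutoScroll` hook isn’t listed in this article; a minimal sketch of such a hook, assuming it simply scrolls its element to the bottom whenever the messages change, could be:

```ts
import { useEffect, useRef } from "react";

// Sketch: returns a ref and scrolls that element to the bottom
// whenever `dep` (here, the messages array) changes.
export function useAutoScroll<T extends HTMLElement>(dep: unknown) {
  const ref = useRef<T>(null);

  useEffect(() => {
    if (ref.current) {
      ref.current.scrollTop = ref.current.scrollHeight;
    }
  }, [dep]);

  return ref;
}
```

Next is `react-frontend/src/app/ui/chat/ChatMessage.tsx`, which renders an individual message, including markdown, math, and highlighted code: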
```tsx
import { marked } from "marked";
import { Message } from "$/app/lib/types";
import { ThinkingAnimation } from "$/app/ui/reusables";
import "katex/dist/katex.min.css";
import katex from "katex";
import hljs from "highlight.js";
import "highlight.js/styles/base16/horizon-dark.min.css";

// Helper function to preprocess Unicode characters
const preprocessUnicode = (text: string): string => {
  // Handle escaped Unicode sequences
  const unescaped = text.replace(
    /\\u\{?([a-fA-F0-9]{4,6})\}?/g,
    (_, codePoint) => String.fromCodePoint(parseInt(codePoint, 16))
  );
  return unescaped;
};

marked.setOptions({
  gfm: true,
  breaks: true,
  pedantic: false,
});

// Custom renderer to handle streaming better
const renderer = new marked.Renderer();

// Add math support
const mathRenderer = {
  name: "math",
  level: "inline",
  start(src: string) {
    return src.match(/\$/)?.index;
  },
  tokenizer(src: string) {
    const match = src.match(/^\$\$([^$\n]+?)\$\$|^\$([^$\n]+?)\$/);
    if (match) {
      const isDisplay = match[0].startsWith("$$");
      return {
        type: "math",
        raw: match[0],
        text: (isDisplay ? match[1] : match[2]).trim(),
        isDisplay,
      };
    }
  },
  renderer(token: any) {
    try {
      return katex.renderToString(token.text, {
        throwOnError: false,
        displayMode: token.isDisplay,
      });
    } catch (err) {
      return token.raw;
    }
  },
};

const codeHighlightExtension = {
  name: "code-highlight",
  level: "block",
  start(src: string) {
    return src.match(/^```/)?.index;
  },
  tokenizer(src: string) {
    const match = src.match(/^```(\S*)\n([\s\S]*?)```/);
    if (match) {
      return {
        type: "code-highlight",
        raw: match[0],
        lang: match[1],
        text: match[2].trim(),
      };
    }
  },
  renderer(token: any) {
    if (token.lang && hljs.getLanguage(token.lang)) {
      try {
        const highlighted = hljs.highlight(token.text, {
          language: token.lang,
          ignoreIllegals: true,
        }).value;
        return `<pre><code class="hljs language-${token.lang}">${highlighted}</code></pre>`;
      } catch (err) {
        console.error("Highlight.js error:", err);
      }
    }
    return `<pre><code>${token.text}</code></pre>`;
  },
};

marked.use({
  extensions: [mathRenderer, codeHighlightExtension],
});

interface ChatMessageProps {
  message: Message;
}

export default function ChatMessage({ message }: ChatMessageProps) {
  return (
    <div
      className={`flex ${
        message.sender === "user" ? "justify-end" : "justify-start"
      } mb-4`}
    >
      <div
        className={`max-w-[80%] p-3 rounded-lg ${
          message.sender === "user"
            ? "bg-[#f5f5f5] text-[#171717] dark:bg-[#1a1a1a] dark:text-[#ededed]"
            : "bg-[#fafafa] text-[#171717] dark:bg-[#141414] dark:text-[#ededed]"
        }`}
      >
        {message.loading ? (
          <ThinkingAnimation />
        ) : (
          <div className="prose article-content max-w-none animate-fade-in">
            {/* User and bot messages are currently rendered identically through marked */}
            <span
              dangerouslySetInnerHTML={{
                __html: marked(preprocessUnicode(message.text), { renderer }),
              }}
            />
          </div>
        )}
      </div>
    </div>
  );
}
```
The `ChatMessage` component is responsible for rendering individual chat messages. Its key feature is the use of `marked.js` to render Markdown content, with extensions for syntax highlighting (via `highlight.js`) and math rendering (via `KaTeX`). The component also includes logic to display a “thinking” animation while the bot is generating a response.
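To see the two extensions in action, here is a minimal sketch (not part of the repository; the `sample` string is made up for illustration) that runs a Markdown string through the configured `marked` instance:

```typescript
import { marked } from "marked";

// Assuming mathRenderer and codeHighlightExtension from above have
// already been registered via marked.use(...).
const sample =
  "Euler's identity: $e^{i\\pi} + 1 = 0$\n\n```python\nprint('hello')\n```";

// marked.parse is synchronous here and returns an HTML string:
// the inline math becomes KaTeX markup (<span class="katex">…</span>)
// and the fenced block becomes <pre><code class="hljs language-python">…</code></pre>.
const html = marked.parse(sample) as string;
console.log(html);
```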
Lastly, let’s look at `react-frontend/src/app/ui/chat/ChatInput.tsx`.
It provides the user interface for entering and sending messages: a text input, controls for toggling between “Masked” and “Auto” modes, and a file upload button (you must select the “Masked” mode to upload, as sketched below). The component was inspired by X’s Grok 3.
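As a rough illustration of that gating (a sketch, not the actual `ChatInput` code; the `mode` state and the `UploadControls` name are assumptions made for this example):

```tsx
import { useState } from "react";

// Simplified sketch of the mode toggle and upload gating.
export function UploadControls() {
  const [mode, setMode] = useState<"masked" | "auto">("auto");

  return (
    <div className="flex items-center gap-2">
      <button onClick={() => setMode("masked")}>Masked</button>
      <button onClick={() => setMode("auto")}>Auto</button>
      {/* The file input is only enabled in "masked" mode. */}
      <input
        type="file"
        accept="image/*,.pdf"
        disabled={mode !== "masked"}
        onChange={() => {
          /* handleFileUpload (discussed next) would be invoked here. */
        }}
      />
    </div>
  );
}
```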
An important thing to note is that in `handleFileUpload`, I used a Next.js server action (`uploadFiles`) to communicate with the server for file uploads:
typescript"use server";import { BASE_API_URL } from "$/app/lib/constants";export async function uploadFiles(formData: FormData) {try {const response = await fetch(`${BASE_API_URL}/api/extract`, {method: "POST",body: formData,});if (!response.ok) {throw new Error("Upload failed");}return await response.json();} catch (error) {console.error("Upload error:", error);throw error;}}typescript "use server"; import { BASE_API_URL } from "$/app/lib/constants"; export async function uploadFiles(formData: FormData) { try { const response = await fetch(`${BASE_API_URL}/api/extract`, { method: "POST", body: formData, }); if (!response.ok) { throw new Error("Upload failed"); } return await response.json(); } catch (error) { console.error("Upload error:", error); throw error; } }typescript "use server"; import { BASE_API_URL } from "$/app/lib/constants"; export async function uploadFiles(formData: FormData) { try { const response = await fetch(`${BASE_API_URL}/api/extract`, { method: "POST", body: formData, }); if (!response.ok) { throw new Error("Upload failed"); } return await response.json(); } catch (error) { console.error("Upload error:", error); throw error; } }
The key reason for using a server action is to avoid CORS errors such as `Cross-Origin Request Blocked: The Same Origin Policy disallows reading the remote resource at ...`. Because the upload request executes on the Next.js server rather than in the browser, it is not subject to the browser’s same-origin policy, so we can communicate with the backend API without configuring CORS headers.
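On the client side, the call from `handleFileUpload` into the action looks roughly like this (a sketch; the import path and the `"files"` field name are assumptions, and the real component’s state handling may differ):

```typescript
import { uploadFiles } from "$/app/lib/actions"; // path assumed for illustration

async function handleFileUpload(files: FileList) {
  const formData = new FormData();
  for (const file of Array.from(files)) {
    formData.append("files", file);
  }

  // The server action runs on the Next.js server, so from the browser's
  // perspective the request never leaves the same origin.
  const extracted = await uploadFiles(formData);
  console.log("Extracted content:", extracted);
}
```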
Apologies for the long article; I chose not to split it so I could finish it in one go. There are still a couple of cleanups to do, but I will leave those to you. Bye for now.
Outro
Enjoyed this article? I’m a Software Engineer, Technical Writer, and Technical Support Engineer actively seeking new opportunities, particularly in areas related to web security, finance, healthcare, and education. If you think my expertise aligns with your team’s needs, let’s chat! You can find me on LinkedIn and X. I am also an email away.
If you found this article valuable, consider sharing it with your network to help spread the knowledge!
Original article: How to Build a Dual-LLM Chat Application with Next.js, Python, and WebSocket Streaming