minimal_inference_server.py

A simplified LLM inference server demonstrating core architecture

This script implements a minimal inference server showing the key components: request handling, batching, and token generation.

What It Does

  1. Creates a simple request queue
  2. Implements basic batching logic
  3. Simulates the prefill/decode loop
  4. Interleaves prefill and decode, a simplified form of continuous batching (see the driver sketch below)
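
For orientation, here is a minimal driver showing how the pieces fit together. It uses the classes defined under Source Code below and assumes they are importable (for example, by appending this snippet to the script in place of its own main() call):

import asyncio

async def demo():
    server = InferenceServer(max_batch_size=2)

    # Submit two requests; they land in the scheduler's waiting queue
    await server.generate("Hello, world!", max_tokens=8)
    await server.generate("What is batching?", max_tokens=8)

    # Interleaves prefill of waiting requests with decode of the running batch
    await server.run_until_complete()

    for req in server.scheduler.completed:
        print(f"Request {req.id}: {len(req.generated_tokens)} tokens")

asyncio.run(demo())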

Architecture

┌───────────────────────────────────────────────────────┐
│               Minimal Inference Server                │
│                                                       │
│  Prompt ──► Scheduler ──► Model Runner ──► Output     │
│                                                       │
│  Components:                                          │
│  - SimpleTokenizer: Turns prompts into token IDs      │
│  - Scheduler: Queues requests and forms batches       │
│  - SimpleModelRunner: Simulates prefill and decode    │
│  - InferenceServer: Orchestrates the generation loop  │
└───────────────────────────────────────────────────────┘
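
In the Source Code below, that wiring appears directly in the InferenceServer constructor (reproduced here with orientation comments added):

class InferenceServer:
    def __init__(self, max_batch_size: int = 4):
        self.tokenizer = SimpleTokenizer()          # prompt -> token IDs
        self.model_runner = SimpleModelRunner()     # simulated prefill/decode
        self.scheduler = Scheduler(max_batch_size)  # queueing + batching
        self.request_counter = 0                    # assigns request ids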

Run It

python tutorial/part3-inference/chapter08-server-anatomy/scripts/minimal_inference_server.py

Key Learning Points

Request Lifecycle (conceptual names; the sketch after this snippet maps them onto this script's classes):

# 1. Request arrives
request = Request(prompt="Hello, world!")

# 2. Tokenize
tokens = tokenizer.encode(request.prompt)

# 3. Add to queue
queue.add(request)

# 4. Batch processing
batch = batcher.get_next_batch()

# 5. Prefill (process prompt)
kv_cache = model.prefill(batch)

# 6. Decode (generate tokens)
while not done:
    next_token = model.decode(kv_cache)
    yield next_token
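
The names above are conceptual (this demo has no real KV cache). In the script itself, the same lifecycle maps onto the concrete classes roughly as follows; this condensed sketch mirrors InferenceServer.run_step and must run inside an async function:

server = InferenceServer(max_batch_size=4)

# 1-3. Tokenize the prompt, wrap it in a GenerateRequest, enqueue it
request = await server.generate("Hello, world!", max_tokens=20)

# 4. One scheduling step: either prefill a waiting request...
batch = server.scheduler.get_next_batch()
if batch.is_prefill:
    # 5. Prefill: process the whole prompt, get the first token
    token = await server.model_runner.prefill(batch.requests[0])
    server.scheduler.process_prefill_result(batch.requests[0], token)
else:
    # 6. Decode: generate one token for every request in the running batch
    tokens = await server.model_runner.decode(batch.requests)
    for req, tok in zip(batch.requests, tokens):
        server.scheduler.process_decode_result(req, tok)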

What This Demonstrates

  • Separation of concerns: Each component has a single responsibility
  • Queue management: Requests are admitted in FIFO order
  • Batching strategy: Multiple requests share GPU compute during decode (see the toy cost model below)
  • Two-phase inference: Prefill then decode
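
To make the batching point concrete, here is a toy cost model. The 8 ms / 2 ms numbers are illustrative assumptions, not measurements from this script or any real GPU: each decode step is assumed to cost a fixed overhead plus a small per-sequence amount, so throughput rises with batch size while per-request latency also grows.

fixed_ms, per_seq_ms = 8.0, 2.0   # assumed step cost: overhead + per-sequence work
tokens_per_request = 20

for batch_size in (1, 4, 8):
    step_ms = fixed_ms + per_seq_ms * batch_size
    throughput = batch_size * 1000 / step_ms         # tokens/second across the batch
    latency_s = tokens_per_request * step_ms / 1000  # time for one request to finish
    print(f"batch={batch_size}: {throughput:.0f} tok/s, {latency_s:.2f}s per request")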

What’s Missing (Real Systems)

  • KV cache management (Chapter 9; a sketch of the idea follows this list)
  • CUDA graph optimization (Chapter 10)
  • Speculative decoding (Chapter 11)
  • Tensor parallelism for large models
  • Production error handling
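
For a flavor of the first missing piece: a real server keeps each request's key/value tensors between decode steps instead of recomputing them. The class below is a hypothetical sketch of that bookkeeping, not part of this script or of any real server's API:

class KVCachePool:
    """Hypothetical per-request KV bookkeeping (illustrative only)."""

    def __init__(self):
        # request id -> list of (key, value) pairs, one per processed position
        self.entries = {}

    def allocate(self, request_id):
        return self.entries.setdefault(request_id, [])

    def append(self, request_id, key, value):
        # Called once per decode step with the new position's K/V
        self.entries.setdefault(request_id, []).append((key, value))

    def free(self, request_id):
        # Reclaim memory as soon as the request finishes
        self.entries.pop(request_id, None)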

Source Code

#!/usr/bin/env python3
"""
Minimal LLM Inference Server

This script demonstrates the core components of an inference server:
- Request management
- Simple batching
- Token generation loop

This is educational, not production-ready. Real servers like vLLM and
SGLang have much more sophisticated implementations.

Usage:
    python minimal_inference_server.py
    python minimal_inference_server.py --num-requests 10
"""

import argparse
import asyncio
import time
from dataclasses import dataclass, field
from typing import List, Optional
from collections import deque
import random


@dataclass
class GenerateRequest:
    """A request to generate text."""
    id: int
    prompt: str
    prompt_tokens: List[int]
    max_tokens: int = 50
    temperature: float = 1.0
    created_at: float = field(default_factory=time.time)

    # Tracking
    generated_tokens: List[int] = field(default_factory=list)
    is_finished: bool = False
    prefill_done: bool = False
    finished_at: Optional[float] = None


@dataclass
class Batch:
    """A batch of requests to process together."""
    requests: List[GenerateRequest]
    is_prefill: bool  # True for prefill, False for decode


class SimpleTokenizer:
    """
    A simplified tokenizer for demonstration.

    Real tokenizers (like SentencePiece or tiktoken) are more complex.
    """

    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        # Simple word-based tokenization
        self.token_to_id = {"<pad>": 0, "<eos>": 1, "<unk>": 2}
        self.id_to_token = {0: "<pad>", 1: "<eos>", 2: "<unk>"}

    def encode(self, text: str) -> List[int]:
        """Convert text to token IDs."""
        # Simplified: hash each word to a repeatable in-process token ID
        words = text.lower().split()
        tokens = []
        for word in words:
            # Hash word to get consistent token ID
            token_id = hash(word) % (self.vocab_size - 3) + 3
            tokens.append(token_id)
        return tokens

    def decode(self, token_ids: List[int]) -> str:
        """Convert token IDs back to text."""
        # Simplified: just return placeholder
        return f"[Generated {len(token_ids)} tokens]"


class SimpleModelRunner:
    """
    A simplified model runner for demonstration.

    Real model runners load actual neural networks and run GPU inference.
    """

    def __init__(self, vocab_size: int = 1000, latency_ms: float = 10):
        self.vocab_size = vocab_size
        self.latency_ms = latency_ms

    async def prefill(self, request: GenerateRequest) -> int:
        """
        Process prompt and return first generated token.

        Real prefill:
        1. Run all prompt tokens through model in parallel
        2. Build KV cache for all tokens
        3. Sample first output token
        """
        # Simulate compute time (proportional to prompt length)
        prompt_len = len(request.prompt_tokens)
        await asyncio.sleep(self.latency_ms * prompt_len / 100)

        # "Generate" first token
        first_token = random.randint(3, self.vocab_size - 1)
        return first_token

    async def decode(self, batch: List[GenerateRequest]) -> List[int]:
        """
        Generate next token for each request in batch.

        Real decode:
        1. Run single token through model for each request
        2. Update KV cache with new KV pairs
        3. Sample next token for each request
        """
        # Simulate compute time (roughly constant per batch)
        await asyncio.sleep(self.latency_ms)

        # "Generate" next tokens
        next_tokens = []
        for req in batch:
            # 10% chance of generating EOS
            if random.random() < 0.1:
                next_tokens.append(1)  # EOS
            else:
                next_tokens.append(random.randint(3, self.vocab_size - 1))
        return next_tokens


class Scheduler:
    """
    Manages request queue and batching decisions.

    Key responsibilities:
    1. Accept new requests
    2. Decide which requests to process together
    3. Manage prefill vs decode scheduling
    """

    def __init__(self, max_batch_size: int = 4):
        self.max_batch_size = max_batch_size
        self.waiting_queue: deque = deque()  # Requests waiting for prefill
        self.running_batch: List[GenerateRequest] = []  # Requests in decode phase
        self.completed: List[GenerateRequest] = []

    def add_request(self, request: GenerateRequest):
        """Add a new request to the waiting queue."""
        self.waiting_queue.append(request)
        print(f"[Scheduler] Added request {request.id} to queue "
              f"(queue size: {len(self.waiting_queue)})")

    def get_next_batch(self) -> Optional[Batch]:
        """
        Decide what to process next.

        Strategy (simplified):
        1. If we have requests waiting AND room in running batch, do prefill
        2. If running batch has requests, do decode
        """
        # Check for finished requests first
        self.running_batch = [r for r in self.running_batch if not r.is_finished]

        # Prefill a waiting request if we have capacity
        # (this demo prefills one request at a time)
        if (self.waiting_queue and
                len(self.running_batch) < self.max_batch_size):
            request = self.waiting_queue.popleft()
            return Batch(requests=[request], is_prefill=True)

        # Decode existing requests
        if self.running_batch:
            return Batch(requests=self.running_batch, is_prefill=False)

        return None

    def process_prefill_result(self, request: GenerateRequest, token: int):
        """Handle result from prefill."""
        request.prefill_done = True
        request.generated_tokens.append(token)
        self.running_batch.append(request)
        print(f"[Scheduler] Request {request.id} finished prefill, "
              f"added to running batch (size: {len(self.running_batch)})")

    def process_decode_result(self, request: GenerateRequest, token: int):
        """Handle result from decode."""
        request.generated_tokens.append(token)

        # Check if finished
        if token == 1 or len(request.generated_tokens) >= request.max_tokens:
            request.is_finished = True
            request.finished_at = time.time()
            self.completed.append(request)
            print(f"[Scheduler] Request {request.id} finished "
                  f"({len(request.generated_tokens)} tokens)")

    def has_work(self) -> bool:
        """Check if there's more work to do."""
        return bool(self.waiting_queue or self.running_batch)


class InferenceServer:
    """
    Main inference server orchestrating all components.
    """

    def __init__(self, max_batch_size: int = 4):
        self.tokenizer = SimpleTokenizer()
        self.model_runner = SimpleModelRunner()
        self.scheduler = Scheduler(max_batch_size)
        self.request_counter = 0

    async def generate(self, prompt: str, max_tokens: int = 50) -> GenerateRequest:
        """Submit a generation request."""
        # Tokenize
        tokens = self.tokenizer.encode(prompt)

        # Create request
        request = GenerateRequest(
            id=self.request_counter,
            prompt=prompt,
            prompt_tokens=tokens,
            max_tokens=max_tokens,
        )
        self.request_counter += 1

        # Submit to scheduler
        self.scheduler.add_request(request)

        return request

    async def run_step(self) -> bool:
        """Run one step of inference."""
        batch = self.scheduler.get_next_batch()
        if batch is None:
            return False

        if batch.is_prefill:
            # Prefill phase
            request = batch.requests[0]
            print(f"[Server] Prefill request {request.id} "
                  f"({len(request.prompt_tokens)} prompt tokens)")

            token = await self.model_runner.prefill(request)
            self.scheduler.process_prefill_result(request, token)

        else:
            # Decode phase
            print(f"[Server] Decode batch of {len(batch.requests)} requests")

            tokens = await self.model_runner.decode(batch.requests)
            for request, token in zip(batch.requests, tokens):
                self.scheduler.process_decode_result(request, token)

        return True

    async def run_until_complete(self):
        """Run until all requests are complete."""
        while self.scheduler.has_work():
            await self.run_step()


async def run_demo(num_requests: int, max_batch_size: int):
    """Run a demonstration of the inference server."""
    print("=" * 60)
    print(" MINIMAL INFERENCE SERVER DEMO")
    print("=" * 60)

    server = InferenceServer(max_batch_size=max_batch_size)

    # Sample prompts
    prompts = [
        "What is the capital of France?",
        "Explain quantum computing in simple terms.",
        "Write a haiku about programming.",
        "What is machine learning?",
        "Tell me a joke.",
        "How does the internet work?",
        "What is the meaning of life?",
        "Describe a beautiful sunset.",
    ]

    print(f"\nConfiguration:")
    print(f"  Max batch size: {max_batch_size}")
    print(f"  Number of requests: {num_requests}")
    print(f"\n{'─' * 60}\n")

    # Submit requests
    requests = []
    for i in range(num_requests):
        prompt = prompts[i % len(prompts)]
        request = await server.generate(prompt, max_tokens=20)
        requests.append(request)

    print(f"\n{'─' * 60}\n")
    print("Processing requests...\n")

    # Process all requests
    start_time = time.time()
    await server.run_until_complete()
    total_time = time.time() - start_time

    # Print results
    print(f"\n{'─' * 60}")
    print(" RESULTS")
    print(f"{'─' * 60}\n")

    total_tokens = 0
    for req in server.scheduler.completed:
        latency = req.finished_at - req.created_at
        print(f"Request {req.id}: {len(req.generated_tokens)} tokens, "
              f"{latency:.3f}s latency")
        total_tokens += len(req.generated_tokens)

    print(f"\n{'─' * 60}")
    print(" SUMMARY")
    print(f"{'─' * 60}")
    print(f"Total requests: {num_requests}")
    print(f"Total tokens generated: {total_tokens}")
    print(f"Total time: {total_time:.3f}s")
    print(f"Throughput: {total_tokens / total_time:.1f} tokens/second")

    # Explain what's happening
    print(f"\n{'─' * 60}")
    print(" WHAT THIS DEMONSTRATES")
    print(f"{'─' * 60}")
    print("""
1. REQUEST FLOW:
   Prompt → Tokenizer → Scheduler → Model Runner → Response

2. PREFILL vs DECODE:
   - Prefill: Process entire prompt (one request at a time here)
   - Decode: Generate tokens in batches

3. BATCHING:
   - Multiple requests share GPU compute during decode
   - Higher batch size → higher throughput but higher latency

4. CONTINUOUS BATCHING (simplified):
   - New requests can start prefill while others decode
   - Finished requests exit, making room for new ones

5. LIMITATIONS OF THIS DEMO:
   - No actual model (just simulated delays)
   - No KV cache management
   - No memory management
   - No streaming output
   - Simplified scheduling logic
""")


def main():
    parser = argparse.ArgumentParser(description="Minimal Inference Server Demo")
    parser.add_argument("--num-requests", "-n", type=int, default=5,
                        help="Number of requests to process")
    parser.add_argument("--batch-size", "-b", type=int, default=4,
                        help="Maximum batch size")
    args = parser.parse_args()

    asyncio.run(run_demo(args.num_requests, args.batch_size))


if __name__ == "__main__":
    main()