#!/usr/bin/env python3 import argparse import json import random import re import time import sys import os import threading from http.server import HTTPServer, BaseHTTPRequestHandler from typing import Dict, List, Optional from dataclasses import dataclass from pathlib import Path import datasets # Set cache directory for HuggingFace datasets cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) os.environ["HF_DATASETS_CACHE"] = str(cache_dir) def dice(s1: str, s2: str) -> float: """Calculate Dice coefficient between two strings based on bigram overlap.""" if not s1 and not s2: return 1.0 def _bigrams(s: str): return [s[i : i + 2] for i in range(len(s) - 1)] bigrams1 = _bigrams(s1) bigrams2 = _bigrams(s2) if not bigrams1 and not bigrams2: return 1.0 from collections import Counter freq1 = Counter(bigrams1) freq2 = Counter(bigrams2) intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1) dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2)) return dice_coeff def debug_log(message: str): """Log debug messages to both stdout and a file""" print(message, file=sys.stderr) with open("/tmp/simulator-debug.log", "a") as f: f.write(message + "\n") simulator: Optional["Simulator"] = None @dataclass class EvalState: id: str tasks: List[str] task_states: Dict[str, Dict] sampling_config: Dict def normalize_number(s: str) -> Optional[int]: match = re.match(r"\d+", s) # match digits from the start if not match: return None return int(match.group(0)) class AimeDataset: def __init__(self, split: str = "train"): self.split = split self.questions: List[Dict] = [] self._load_dataset() def _load_dataset(self): print(f"Loading AIME dataset (split: {self.split})...") cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" if cache_path.exists(): print(f"Using cached dataset from {cache_path}") ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) else: ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) self.questions = list(ds) print(f"AIME dataset loaded: {len(self.questions)} questions") def find_question(self, request_text: str) -> Optional[Dict]: best_match = None best_distance = -1 best_index = -1 for i, question in enumerate(self.questions): question_text = question["problem"] request_lower = request_text.lower() question_lower = question_text.lower() # Exact match if question_lower == request_lower: debug_log(f"DEBUG: Found exact match at index {i}") return question # Remove LaTeX formatting for more flexible matching question_no_latex = re.sub(r'\$[^$]+\$', '', question_text) if question_no_latex.lower() == request_lower: debug_log(f"DEBUG: Found match (no LaTeX) at index {i}") return question # Calculate Dice coefficient for partial matches # Only consider if request is at least 50% of question length if len(request_lower) >= len(question_lower) * 0.5: distance = dice(question_lower, request_lower) if distance > best_distance: best_distance = distance best_match = question best_index = i if best_match and best_distance > 0.3: # Threshold for partial match debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}") return best_match debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...") return None def get_answer(self, question: Dict) -> str: answer = question["answer"] if isinstance(answer, str): normalized = normalize_number(answer) return str(normalized) if normalized is not None else answer return str(answer) class Simulator: def __init__( self, port: int = 8033, host: str = "localhost", success_rate: float = 0.8, dataset_split: str = "train" ): self.port = port self.host = host self.success_rate = success_rate self.dataset = AimeDataset(dataset_split) self.eval_state = EvalState( id="aime-2025", tasks=["aime"], task_states={}, sampling_config={"temperature": 0, "max_tokens": 2048} ) def _generate_response( self, question: Dict, should_be_correct: bool ) -> Dict: expected_answer = self.dataset.get_answer(question) if should_be_correct: response_text = expected_answer else: response_text = self._generate_wrong_answer(question) return { "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion", "created": int(time.time()), "model": "llama", "choices": [ { "index": 0, "message": { "role": "assistant", "content": response_text }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150 } } def _generate_wrong_answer(self, question: Dict) -> str: expected_answer = self.dataset.get_answer(question) if expected_answer.isdigit(): wrong_answer = str(int(expected_answer) + 1) else: wrong_answer = expected_answer + " (wrong)" return wrong_answer def _process_request(self, request_data: Dict) -> Dict: messages = request_data.get("messages", []) if not messages: return {"error": "No messages in request"} request_text = messages[0].get("content", "") debug_log(f"DEBUG: Received request with content: {request_text[:150]}...") question = self.dataset.find_question(request_text) if not question: debug_log(f"DEBUG: find_question returned None") return {"error": "No matching question found"} should_be_correct = random.random() < self.success_rate response = self._generate_response(question, should_be_correct) task_id = "aime" self.eval_state.task_states[task_id] = { "correct": should_be_correct, "expected": self.dataset.get_answer(question), "predicted": response["choices"][0]["message"]["content"] } return response class RequestHandler(BaseHTTPRequestHandler): def do_POST(self): if self.path != "/v1/chat/completions": self._send_json({"error": "Not found"}, 404) return try: content_length = int(self.headers.get("Content-Length", 0)) body = self.rfile.read(content_length) request_data = json.loads(body) if body else None if not request_data: self._send_json({"error": "Invalid JSON"}, 400) return if simulator is None: self._send_json({"error": "Simulator not initialized"}, 500) return response = simulator._process_request(request_data) self._send_json(response, 200) except json.JSONDecodeError: self._send_json({"error": "Invalid JSON"}, 400) except Exception as e: print(f"Error processing request: {e}") self._send_json({"error": str(e)}, 500) def _send_json(self, data: dict, status: int = 200): body = json.dumps(data).encode("utf-8") self.send_response(status) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(body))) self.end_headers() self.wfile.write(body) def log_message(self, format, *args): # Suppress default request logging pass def main(): parser = argparse.ArgumentParser( description="llama-server simulator for testing eval scripts" ) parser.add_argument( "--port", type=int, default=8033, help="Server port (default: 8033)" ) parser.add_argument( "--host", type=str, default="localhost", help="Server host (default: localhost)" ) parser.add_argument( "--success-rate", type=float, default=0.8, help="Success rate 0-1 (default: 0.8)" ) parser.add_argument( "--dataset-split", type=str, default="train", help="AIME dataset split to use (default: train)" ) args = parser.parse_args() global simulator simulator = Simulator( port=args.port, host=args.host, success_rate=args.success_rate, dataset_split=args.dataset_split ) server = HTTPServer((args.host, args.port), RequestHandler) server_thread = threading.Thread(target=server.serve_forever, daemon=True) server_thread.start() print("\n=== llama-server-simulator ===") print(f"Server running on http://{args.host}:{args.port}") print(f"Success rate: {args.success_rate}") print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions") print("\nPress Ctrl+C to stop\n") try: server_thread.join() except KeyboardInterrupt: print("\nShutting down...") server.shutdown() if __name__ == "__main__": main()