gomog/internal/protocol/http/server.go

444 lines
11 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package http
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/internal/engine"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// The HTTP front end is split in two: HTTPServer owns routing and the
// listener, RequestHandler owns the engine objects the API operations use.
type (
	// HTTPServer is the HTTP server: a ServeMux with the routes registered,
	// the RequestHandler API requests are dispatched to, and the underlying
	// http.Server.
	HTTPServer struct {
		mux     *http.ServeMux
		handler *RequestHandler
		server  *http.Server
	}

	// RequestHandler holds the in-memory store, the CRUD handler and the
	// aggregation engine that API operations are executed against.
	RequestHandler struct {
		store *engine.MemoryStore
		crud  *engine.CRUDHandler
		agg   *engine.AggregationEngine
	}
)
// NewRequestHandler returns a RequestHandler wired to the given memory store,
// CRUD handler and aggregation engine. The pointers are stored as-is.
func NewRequestHandler(store *engine.MemoryStore, crud *engine.CRUDHandler, agg *engine.AggregationEngine) *RequestHandler {
	h := new(RequestHandler)
	h.store = store
	h.crud = crud
	h.agg = agg
	return h
}
// NewHTTPServer creates an HTTPServer that will listen on addr and dispatch
// API requests to handler. Routes are registered immediately; the server is
// not started until Start is called.
//
// ReadHeaderTimeout and IdleTimeout are set so that a client that trickles
// request headers, or parks an idle keep-alive connection, cannot hold a
// connection open forever. No overall ReadTimeout/WriteTimeout is set, so
// large or slow responses are not cut off mid-stream.
func NewHTTPServer(addr string, handler *RequestHandler) *HTTPServer {
	const (
		readHeaderTimeout = 10 * time.Second
		idleTimeout       = 120 * time.Second
	)

	s := &HTTPServer{
		mux:     http.NewServeMux(),
		handler: handler,
	}
	s.registerRoutes()

	s.server = &http.Server{
		Addr:              addr,
		Handler:           s.mux,
		ReadHeaderTimeout: readHeaderTimeout,
		IdleTimeout:       idleTimeout,
	}
	return s
}
// Start serves HTTP on the configured address. It blocks until the server
// stops and returns the error from http.Server.ListenAndServe
// (http.ErrServerClosed after Shutdown).
func (s *HTTPServer) Start() error {
	srv := s.server
	return srv.ListenAndServe()
}
// Shutdown gracefully stops the server by delegating to
// http.Server.Shutdown; ctx bounds how long it waits for in-flight requests.
func (s *HTTPServer) Shutdown(ctx context.Context) error {
	srv := s.server
	return srv.Shutdown(ctx)
}
// registerRoutes installs the server's routes on s.mux:
//   - "/api/v1/" -> handleAPI (database/collection operations)
//   - "/health"  -> handleHealth (health check)
//   - "/"        -> handleRoot (catch-all for every other path)
func (s *HTTPServer) registerRoutes() {
	routes := []struct {
		pattern string
		handle  http.HandlerFunc
	}{
		{"/api/v1/", s.handleAPI},
		{"/health", s.handleHealth},
		{"/", s.handleRoot},
	}
	for _, rt := range routes {
		s.mux.HandleFunc(rt.pattern, rt.handle)
	}
}
// handleRoot serves the server banner (name, version, status) at "/".
//
// "/" is registered as the ServeMux catch-all, so every path not matched by a
// more specific route also arrives here. Those paths get a 404 JSON error
// rather than a misleading 200 "running" banner.
func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		s.sendError(w, http.StatusNotFound, "Not found: "+r.URL.Path)
		return
	}
	response := map[string]interface{}{
		"name":    "Gomog Server",
		"version": "1.0.0",
		"status":  "running",
	}
	s.sendJSON(w, http.StatusOK, response)
}
// handleHealth always answers 200 with {"status":"healthy"}; it does not
// inspect the store or any other component.
func (s *HTTPServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	s.sendJSON(w, http.StatusOK, map[string]interface{}{"status": "healthy"})
}
// handleAPI dispatches /api/v1/{database}/{collection}/{operation} requests
// to the matching RequestHandler method (find, insert, update, delete,
// aggregate).
//
// Malformed paths get 400; see parseAPIPath for what counts as malformed.
// Unknown operations also get 400. Before dispatching, the collection is
// passed through loadCollectionIfNeeded; a failure there yields 500.
func (s *HTTPServer) handleAPI(w http.ResponseWriter, r *http.Request) {
	dbName, collection, operation, ok := parseAPIPath(r.URL.Path)
	if !ok {
		s.sendError(w, http.StatusBadRequest, "Invalid path. Expected: /api/v1/{database}/{collection}/{operation}")
		return
	}

	// Make sure the collection is available in memory before operating on it.
	if err := s.loadCollectionIfNeeded(dbName, collection); err != nil {
		s.sendError(w, http.StatusInternalServerError, "Failed to load collection: "+err.Error())
		return
	}

	switch operation {
	case "find":
		s.handler.HandleFind(w, r, dbName, collection)
	case "insert":
		s.handler.HandleInsert(w, r, dbName, collection)
	case "update":
		s.handler.HandleUpdate(w, r, dbName, collection)
	case "delete":
		s.handler.HandleDelete(w, r, dbName, collection)
	case "aggregate":
		s.handler.HandleAggregate(w, r, dbName, collection)
	default:
		s.sendError(w, http.StatusBadRequest, "Unknown operation: "+operation)
	}
}

// parseAPIPath splits an /api/v1/{database}/{collection}/{operation} URL path
// into its segments. One trailing slash is tolerated. ok is false unless there
// are exactly three non-empty segments and the database name has no dot.
func parseAPIPath(urlPath string) (dbName, collection, operation string, ok bool) {
	rest := strings.TrimPrefix(urlPath, "/api/v1/")
	rest = strings.TrimSuffix(rest, "/")
	parts := strings.Split(rest, "/")
	if len(parts) != 3 {
		return "", "", "", false
	}
	for _, p := range parts {
		if p == "" {
			return "", "", "", false
		}
	}
	// Collections are addressed as "<db>.<collection>", so a dot in the
	// database name would make namespaces ambiguous ("a.b"+"c" vs "a"+"b.c").
	if strings.Contains(parts[0], ".") {
		return "", "", "", false
	}
	return parts[0], parts[1], parts[2], true
}
// loadCollectionIfNeeded is meant to ensure "<dbName>.<collection>" is present
// in the memory store before an operation runs. Loading from the backing
// database is not implemented yet. The function looks the collection up,
// but it returns nil whether or not the lookup succeeds.
func (s *HTTPServer) loadCollectionIfNeeded(dbName, collection string) error {
	ns := dbName + "." + collection
	if _, lookupErr := s.handler.store.GetCollection(ns); lookupErr != nil {
		// Not in memory.
		// TODO: load it from the database, e.g.
		// return s.handler.store.LoadCollection(context.Background(), ns)
		return nil
	}
	return nil // already loaded
}
// sendJSON writes data as a JSON response body with the given status code.
//
// The payload is marshalled before anything is written. If encoding fails,
// the client therefore gets a 500 JSON error, not a success status followed by
// an empty or truncated body. A trailing newline is appended to match the
// output of json.Encoder.Encode.
func (s *HTTPServer) sendJSON(w http.ResponseWriter, status int, data interface{}) {
	body, err := json.Marshal(data)
	if err != nil {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"failed to encode response","ok":0,"status":500}` + "\n"))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// Status and headers are committed; a write error here means the client
	// is gone and there is nobody left to report it to.
	_, _ = w.Write(append(body, '\n'))
}
// sendError writes a JSON error envelope {"ok":0,"error":message,"status":status}
// using status as the HTTP status code.
func (s *HTTPServer) sendError(w http.ResponseWriter, status int, message string) {
	payload := map[string]interface{}{
		"ok":     0,
		"error":  message,
		"status": status,
	}
	s.sendJSON(w, status, payload)
}
// HandleFind runs a find against "<dbName>.<collection>" and writes all
// results as a single cursor batch (cursor ID 0).
//
// GET runs with an empty types.FindRequest. POST decodes the request from the
// JSON body (400 on a malformed body). Other methods get 405 and store errors
// get 500. After the store query, Skip and then Limit are applied, followed by
// the projection. Sort is accepted but not implemented yet.
func (h *RequestHandler) HandleFind(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost && r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.FindRequest
	if r.Method == http.MethodPost {
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
	}

	ns := dbName + "." + collection
	docs, err := h.store.Find(ns, req.Filter)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// TODO: apply req.Sort; until then documents come back in store order.

	// Skipping past the end must yield an empty result, not the full set.
	if req.Skip > 0 {
		if req.Skip >= len(docs) {
			docs = docs[len(docs):]
		} else {
			docs = docs[req.Skip:]
		}
	}
	if req.Limit > 0 && req.Limit < len(docs) {
		docs = docs[:req.Limit]
	}

	if len(req.Projection) > 0 {
		docs = applyProjection(docs, req.Projection)
	}

	response := types.Response{
		OK: 1,
		Cursor: &types.Cursor{
			FirstBatch: docs,
			ID:         0,
			NS:         ns,
		},
	}
	w.Header().Set("Content-Type", "application/json")
	// The response is being streamed; an encode/write error cannot be
	// reported to the client at this point.
	_ = json.NewEncoder(w).Encode(response)
}
// HandleInsert inserts every document of a types.InsertRequest JSON body
// (POST only, else 405; 400 on a malformed body) into "<dbName>.<collection>".
// Each document gets a freshly generated ID.
//
// CreatedAt and UpdatedAt are taken from one time.Now() call per document, so
// a freshly inserted document never appears to have been modified. On a store
// error the handler stops and responds 500. Documents inserted before the
// failing one are not rolled back. InsertedIDs maps request index to
// generated ID.
func (h *RequestHandler) HandleInsert(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.InsertRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	ns := dbName + "." + collection
	insertedIDs := make(map[int]string, len(req.Documents))
	for i, data := range req.Documents {
		id := generateID()
		now := time.Now()
		doc := types.Document{
			ID:        id,
			Data:      data,
			CreatedAt: now,
			UpdatedAt: now,
		}
		if err := h.store.Insert(ns, doc); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		insertedIDs[i] = id
	}

	response := types.InsertResult{
		OK:          1,
		N:           len(req.Documents),
		InsertedIDs: insertedIDs,
	}
	w.Header().Set("Content-Type", "application/json")
	// Headers are committed; an encode/write error cannot be reported.
	_ = json.NewEncoder(w).Encode(response)
}
// HandleUpdate applies each op of a types.UpdateRequest JSON body (POST only,
// else 405; 400 on a malformed body) to "<dbName>.<collection>".
//
// N and NModified are the matched and modified counts summed over all ops.
// Each upserted ID is reported with Index set to the position, in
// req.Updates, of the op that produced it. On a store error the handler stops
// and responds 500. Ops applied before the failing one are not rolled back.
func (h *RequestHandler) HandleUpdate(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.UpdateRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	ns := dbName + "." + collection
	totalMatched := 0
	totalModified := 0
	upserted := make([]types.UpsertID, 0)

	for i, op := range req.Updates {
		matched, modified, upsertedIDs, err := h.store.Update(ns, op.Q, op.U, op.Upsert, op.ArrayFilters)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		totalMatched += matched
		totalModified += modified

		// Attribute each upsert to the op that caused it (previously always 0).
		for _, id := range upsertedIDs {
			upserted = append(upserted, types.UpsertID{
				Index: i,
				ID:    id,
			})
		}
	}

	response := types.UpdateResult{
		OK:        1,
		N:         totalMatched,
		NModified: totalModified,
		Upserted:  upserted,
	}
	w.Header().Set("Content-Type", "application/json")
	// Headers are committed; an encode/write error cannot be reported.
	_ = json.NewEncoder(w).Encode(response)
}
// HandleDelete runs every op of a types.DeleteRequest JSON body (POST only,
// else 405; 400 on a malformed body) against "<dbName>.<collection>". The
// total number of removed documents is reported in both N and DeletedCount.
//
// All ops in req.Deletes are applied in order. op.Limit is not enforced yet:
// store.Delete takes no limit argument, so it presumably removes every
// document matching op.Q (TODO: confirm, and add single-document deletes
// for limit=1). On a store error the handler stops and responds 500.
// Earlier ops are not rolled back.
func (h *RequestHandler) HandleDelete(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req types.DeleteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	ns := dbName + "." + collection
	totalDeleted := 0
	for _, op := range req.Deletes {
		// Every op must run. A limit=1 op only limits its own matches; it
		// must not stop later ops in the batch from executing.
		deleted, err := h.store.Delete(ns, op.Q)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		totalDeleted += deleted
	}

	response := types.DeleteResult{
		OK:           1,
		N:            totalDeleted,
		DeletedCount: totalDeleted,
	}
	w.Header().Set("Content-Type", "application/json")
	// Headers are committed; an encode/write error cannot be reported.
	_ = json.NewEncoder(w).Encode(response)
}
// HandleAggregate executes the pipeline from a types.AggregateRequest JSON
// body against "<dbName>.<collection>" via the aggregation engine and writes
// the results. Only POST is accepted (else 405). A malformed body gets 400
// and an engine error gets 500.
func (h *RequestHandler) HandleAggregate(w http.ResponseWriter, r *http.Request, dbName, collection string) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	req := types.AggregateRequest{}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	results, execErr := h.agg.Execute(dbName+"."+collection, req.Pipeline)
	if execErr != nil {
		http.Error(w, execErr.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(types.AggregateResult{OK: 1, Result: results})
}
// applyProjection returns projected copies of docs; the input documents and
// their Data maps are never modified.
//
// The mode is chosen from the projection's fields:
//   - inclusion: some non-_id field is truthy, or the projection is only a
//     truthy "_id". Only the truthy fields are copied. Dotted paths are read
//     with getNestedValue and stored under the dotted key as given.
//   - exclusion: otherwise, e.g. {"secret": 0} or {"_id": 0}. Every field of
//     the document is kept except the falsy ones. A dotted path removes the
//     nested field.
//
// In both modes "_id" (doc.ID) is added unless projection["_id"] is falsy, in
// which case any "_id" key is removed. Result documents carry only ID and the
// projected Data.
func applyProjection(docs []types.Document, projection types.Projection) []types.Document {
	inclusive := projectionIsInclusive(projection)
	result := make([]types.Document, len(docs))
	for i, doc := range docs {
		var projected map[string]interface{}
		if inclusive {
			projected = make(map[string]interface{})
			for field, include := range projection {
				if field != "_id" && isTrue(include) {
					projected[field] = getNestedValue(doc.Data, field)
				}
			}
		} else {
			projected = make(map[string]interface{}, len(doc.Data)+1)
			for k, v := range doc.Data {
				projected[k] = v
			}
			for field, include := range projection {
				if field != "_id" && !isTrue(include) {
					removeProjectedPath(projected, strings.Split(field, "."))
				}
			}
		}

		// _id is included by default and only dropped when explicitly falsy.
		if idSpec, ok := projection["_id"]; !ok || isTrue(idSpec) {
			projected["_id"] = doc.ID
		} else {
			delete(projected, "_id")
		}

		result[i] = types.Document{
			ID:   doc.ID,
			Data: projected,
		}
	}
	return result
}

// projectionIsInclusive reports whether projection lists fields to keep
// (true) rather than fields to drop (false).
func projectionIsInclusive(projection types.Projection) bool {
	for field, v := range projection {
		if field != "_id" && isTrue(v) {
			return true
		}
	}
	// A projection made of a single truthy _id keeps just the _id.
	if v, ok := projection["_id"]; ok && len(projection) == 1 {
		return isTrue(v)
	}
	return false
}

// removeProjectedPath deletes the dotted path parts from m. Nested maps along
// the path are replaced by copies before being modified, so maps shared with
// the source document are never mutated.
func removeProjectedPath(m map[string]interface{}, parts []string) {
	if len(parts) == 1 {
		delete(m, parts[0])
		return
	}
	child, ok := m[parts[0]].(map[string]interface{})
	if !ok {
		return
	}
	clone := make(map[string]interface{}, len(child))
	for k, v := range child {
		clone[k] = v
	}
	m[parts[0]] = clone
	removeProjectedPath(clone, parts[1:])
}
// generateID returns a new document ID. Generation is delegated entirely to
// engine.GenerateID.
func generateID() string {
	id := engine.GenerateID()
	return id
}
// isTrue reports whether a projection-style flag value is "on".
//
// A bool is returned as-is. Any numeric value (all int/uint/float widths, and
// json.Number) is true when non-zero. Every other value, including nil,
// strings and maps, is treated as true, as before. Previously, zeros of
// numeric types other than int and float64 (e.g. int32(0), int64(0)) were
// wrongly reported as true.
func isTrue(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case int:
		return val != 0
	case int8:
		return val != 0
	case int16:
		return val != 0
	case int32:
		return val != 0
	case int64:
		return val != 0
	case uint:
		return val != 0
	case uint8:
		return val != 0
	case uint16:
		return val != 0
	case uint32:
		return val != 0
	case uint64:
		return val != 0
	case float32:
		return val != 0
	case float64:
		return val != 0
	case json.Number:
		f, err := val.Float64()
		if err != nil {
			// Unparseable number: fall back to the default for unknown values.
			return true
		}
		return f != 0
	}
	return true
}
// getNestedValue resolves a dotted key such as "a.b.c" against doc by walking
// nested map[string]interface{} values one segment at a time. It returns nil
// when a segment is missing, or when an intermediate value is not a
// map[string]interface{}.
func getNestedValue(doc map[string]interface{}, key string) interface{} {
	var cur interface{} = doc
	for _, segment := range strings.Split(key, ".") {
		m, ok := cur.(map[string]interface{})
		if !ok {
			return nil
		}
		cur = m[segment]
	}
	return cur
}