444 lines
11 KiB
Go
444 lines
11 KiB
Go
package http
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.kingecg.top/kingecg/gomog/internal/engine"
|
||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||
)
|
||
|
||
// HTTPServer HTTP 服务器
|
||
type HTTPServer struct {
|
||
mux *http.ServeMux
|
||
handler *RequestHandler
|
||
server *http.Server
|
||
}
|
||
|
||
// RequestHandler 请求处理器
|
||
type RequestHandler struct {
|
||
store *engine.MemoryStore
|
||
crud *engine.CRUDHandler
|
||
agg *engine.AggregationEngine
|
||
}
|
||
|
||
// NewRequestHandler 创建请求处理器
|
||
func NewRequestHandler(store *engine.MemoryStore, crud *engine.CRUDHandler, agg *engine.AggregationEngine) *RequestHandler {
|
||
return &RequestHandler{
|
||
store: store,
|
||
crud: crud,
|
||
agg: agg,
|
||
}
|
||
}
|
||
|
||
// NewHTTPServer 创建 HTTP 服务器
|
||
func NewHTTPServer(addr string, handler *RequestHandler) *HTTPServer {
|
||
s := &HTTPServer{
|
||
mux: http.NewServeMux(),
|
||
handler: handler,
|
||
}
|
||
|
||
// 注册路由
|
||
s.registerRoutes()
|
||
|
||
s.server = &http.Server{
|
||
Addr: addr,
|
||
Handler: s.mux,
|
||
}
|
||
|
||
return s
|
||
}
|
||
|
||
// Start 启动 HTTP 服务器
|
||
func (s *HTTPServer) Start() error {
|
||
return s.server.ListenAndServe()
|
||
}
|
||
|
||
// Shutdown 关闭 HTTP 服务器
|
||
func (s *HTTPServer) Shutdown(ctx context.Context) error {
|
||
return s.server.Shutdown(ctx)
|
||
}
|
||
|
||
// registerRoutes 注册路由
|
||
func (s *HTTPServer) registerRoutes() {
|
||
// API v1 路由
|
||
s.mux.HandleFunc("/api/v1/", s.handleAPI)
|
||
|
||
// 健康检查
|
||
s.mux.HandleFunc("/health", s.handleHealth)
|
||
|
||
// 根路径
|
||
s.mux.HandleFunc("/", s.handleRoot)
|
||
}
|
||
|
||
// handleRoot 根路径处理
|
||
func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
|
||
response := map[string]interface{}{
|
||
"name": "Gomog Server",
|
||
"version": "1.0.0",
|
||
"status": "running",
|
||
}
|
||
s.sendJSON(w, http.StatusOK, response)
|
||
}
|
||
|
||
// handleHealth 健康检查
|
||
func (s *HTTPServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||
response := map[string]interface{}{
|
||
"status": "healthy",
|
||
}
|
||
s.sendJSON(w, http.StatusOK, response)
|
||
}
|
||
|
||
// handleAPI 处理 API 请求
|
||
func (s *HTTPServer) handleAPI(w http.ResponseWriter, r *http.Request) {
|
||
// 解析路径:/api/v1/{database}/{collection}/{operation}
|
||
path := strings.TrimPrefix(r.URL.Path, "/api/v1/")
|
||
parts := strings.Split(path, "/")
|
||
|
||
if len(parts) < 3 {
|
||
s.sendError(w, http.StatusBadRequest, "Invalid path. Expected: /api/v1/{database}/{collection}/{operation}")
|
||
return
|
||
}
|
||
|
||
dbName := parts[0]
|
||
collection := parts[1]
|
||
operation := parts[2]
|
||
|
||
// 确保集合已加载到内存
|
||
if err := s.loadCollectionIfNeeded(dbName, collection); err != nil {
|
||
s.sendError(w, http.StatusInternalServerError, "Failed to load collection: "+err.Error())
|
||
return
|
||
}
|
||
|
||
// 根据操作类型分发请求
|
||
switch operation {
|
||
case "find":
|
||
s.handler.HandleFind(w, r, dbName, collection)
|
||
case "insert":
|
||
s.handler.HandleInsert(w, r, dbName, collection)
|
||
case "update":
|
||
s.handler.HandleUpdate(w, r, dbName, collection)
|
||
case "delete":
|
||
s.handler.HandleDelete(w, r, dbName, collection)
|
||
case "aggregate":
|
||
s.handler.HandleAggregate(w, r, dbName, collection)
|
||
default:
|
||
s.sendError(w, http.StatusBadRequest, "Unknown operation: "+operation)
|
||
}
|
||
}
|
||
|
||
// loadCollectionIfNeeded 按需加载集合
|
||
func (s *HTTPServer) loadCollectionIfNeeded(dbName, collection string) error {
|
||
// 简化处理:每次都尝试加载
|
||
// 实际应该检查是否已加载
|
||
fullCollection := dbName + "." + collection
|
||
_, err := s.handler.store.GetCollection(fullCollection)
|
||
if err == nil {
|
||
return nil // 已加载
|
||
}
|
||
|
||
// TODO: 从数据库加载集合
|
||
// return s.store.LoadCollection(context.Background(), fullCollection)
|
||
return nil
|
||
}
|
||
|
||
// sendJSON 发送 JSON 响应
|
||
func (s *HTTPServer) sendJSON(w http.ResponseWriter, status int, data interface{}) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(status)
|
||
json.NewEncoder(w).Encode(data)
|
||
}
|
||
|
||
// sendError 发送错误响应
|
||
func (s *HTTPServer) sendError(w http.ResponseWriter, status int, message string) {
|
||
s.sendJSON(w, status, map[string]interface{}{
|
||
"ok": 0,
|
||
"error": message,
|
||
"status": status,
|
||
})
|
||
}
|
||
|
||
// HandleFind 处理查询请求
|
||
func (h *RequestHandler) HandleFind(w http.ResponseWriter, r *http.Request, dbName, collection string) {
|
||
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req types.FindRequest
|
||
if r.Method == http.MethodPost {
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 执行查询
|
||
fullCollection := dbName + "." + collection
|
||
docs, err := h.store.Find(fullCollection, req.Filter)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 应用排序
|
||
if req.Sort != nil && len(req.Sort) > 0 {
|
||
// TODO: 实现排序逻辑
|
||
}
|
||
|
||
// 应用跳过和限制
|
||
skip := req.Skip
|
||
limit := req.Limit
|
||
|
||
if skip > 0 && skip < len(docs) {
|
||
docs = docs[skip:]
|
||
}
|
||
if limit > 0 && limit < len(docs) {
|
||
docs = docs[:limit]
|
||
}
|
||
|
||
// 应用投影
|
||
if req.Projection != nil && len(req.Projection) > 0 {
|
||
docs = applyProjection(docs, req.Projection)
|
||
}
|
||
|
||
response := types.Response{
|
||
OK: 1,
|
||
Cursor: &types.Cursor{
|
||
FirstBatch: docs,
|
||
ID: 0,
|
||
NS: dbName + "." + collection,
|
||
},
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// HandleInsert 处理插入请求
|
||
func (h *RequestHandler) HandleInsert(w http.ResponseWriter, r *http.Request, dbName, collection string) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req types.InsertRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
fullCollection := dbName + "." + collection
|
||
insertedIDs := make(map[int]string)
|
||
|
||
for i, docData := range req.Documents {
|
||
// 生成 ID
|
||
id := generateID()
|
||
|
||
doc := types.Document{
|
||
ID: id,
|
||
Data: docData,
|
||
CreatedAt: time.Now(),
|
||
UpdatedAt: time.Now(),
|
||
}
|
||
|
||
// 插入到内存
|
||
if err := h.store.Insert(fullCollection, doc); err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
insertedIDs[i] = id
|
||
}
|
||
|
||
response := types.InsertResult{
|
||
OK: 1,
|
||
N: len(req.Documents),
|
||
InsertedIDs: insertedIDs,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// HandleUpdate 处理更新请求
|
||
func (h *RequestHandler) HandleUpdate(w http.ResponseWriter, r *http.Request, dbName, collection string) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req types.UpdateRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
fullCollection := dbName + "." + collection
|
||
totalMatched := 0
|
||
totalModified := 0
|
||
upserted := make([]types.UpsertID, 0)
|
||
|
||
for _, op := range req.Updates {
|
||
matched, modified, upsertedIDs, err := h.store.Update(fullCollection, op.Q, op.U, op.Upsert, op.ArrayFilters)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
totalMatched += matched
|
||
totalModified += modified
|
||
|
||
// 收集 upserted IDs
|
||
for _, id := range upsertedIDs {
|
||
upserted = append(upserted, types.UpsertID{
|
||
Index: 0,
|
||
ID: id,
|
||
})
|
||
}
|
||
}
|
||
|
||
response := types.UpdateResult{
|
||
OK: 1,
|
||
N: totalMatched,
|
||
NModified: totalModified,
|
||
Upserted: upserted,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// HandleDelete 处理删除请求
|
||
func (h *RequestHandler) HandleDelete(w http.ResponseWriter, r *http.Request, dbName, collection string) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req types.DeleteRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
fullCollection := dbName + "." + collection
|
||
totalDeleted := 0
|
||
|
||
for _, op := range req.Deletes {
|
||
deleted, err := h.store.Delete(fullCollection, op.Q)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
totalDeleted += deleted
|
||
|
||
// 如果 limit=1,只删除第一个匹配的文档
|
||
if op.Limit == 1 && deleted > 0 {
|
||
break
|
||
}
|
||
}
|
||
|
||
response := types.DeleteResult{
|
||
OK: 1,
|
||
N: totalDeleted,
|
||
DeletedCount: totalDeleted,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// HandleAggregate 处理聚合请求
|
||
func (h *RequestHandler) HandleAggregate(w http.ResponseWriter, r *http.Request, dbName, collection string) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req types.AggregateRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
fullCollection := dbName + "." + collection
|
||
results, err := h.agg.Execute(fullCollection, req.Pipeline)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
response := types.AggregateResult{
|
||
OK: 1,
|
||
Result: results,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// applyProjection 应用投影
|
||
func applyProjection(docs []types.Document, projection types.Projection) []types.Document {
|
||
result := make([]types.Document, len(docs))
|
||
for i, doc := range docs {
|
||
projected := make(map[string]interface{})
|
||
|
||
// 简单实现:只包含指定的字段
|
||
for field, include := range projection {
|
||
if isTrue(include) && field != "_id" {
|
||
projected[field] = getNestedValue(doc.Data, field)
|
||
}
|
||
}
|
||
|
||
// 总是包含 _id 除非明确排除
|
||
if excludeID, ok := projection["_id"]; !ok || isTrue(excludeID) {
|
||
projected["_id"] = doc.ID
|
||
}
|
||
|
||
result[i] = types.Document{
|
||
ID: doc.ID,
|
||
Data: projected,
|
||
}
|
||
}
|
||
return result
|
||
}
|
||
|
||
// generateID 生成唯一 ID(简化版本)
|
||
func generateID() string {
|
||
return engine.GenerateID()
|
||
}
|
||
|
||
// isTrue reports whether a projection value such as 1 or true means
// "include".
//
// Rules:
//   - A bool is returned as-is.
//   - A value of any built-in integer or floating-point type is true unless
//     it is zero. Numbers decoded by encoding/json into interface{} arrive as
//     float64; programmatically built projections may use other numeric types.
//   - Any other value, including nil and strings, is treated as true.
func isTrue(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case int:
		return val != 0
	case int8:
		return val != 0
	case int16:
		return val != 0
	case int32:
		return val != 0
	case int64:
		return val != 0
	case uint:
		return val != 0
	case uint8:
		return val != 0
	case uint16:
		return val != 0
	case uint32:
		return val != 0
	case uint64:
		return val != 0
	case float32:
		return val != 0
	case float64:
		return val != 0
	}
	return true
}
|
||
|
||
// getNestedValue resolves a dot-separated key such as "a.b.c" by walking
// nested map[string]interface{} values starting at doc.
//
// It returns nil in any of these cases, which callers cannot tell apart:
//   - a segment is missing,
//   - an intermediate value is not a map[string]interface{},
//   - the stored value itself is nil.
func getNestedValue(doc map[string]interface{}, key string) interface{} {
	var node interface{} = doc
	for _, segment := range strings.Split(key, ".") {
		m, isMap := node.(map[string]interface{})
		if !isMap {
			return nil
		}
		node = m[segment]
	}
	return node
}
|