package engine

import (
	"fmt"
	"math/rand"
	"sort"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/pkg/types"
)

// executeReplaceRoot runs the $replaceRoot stage.
//
// spec must be a map containing a "newRoot" expression. When spec is not a
// map, or has no "newRoot" key, docs is returned unchanged. Otherwise the
// stage behaves exactly like $replaceWith applied to the "newRoot" expression.
func (e *AggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
	specMap, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	newRootRaw, exists := specMap["newRoot"]
	if !exists {
		return docs, nil
	}

	return e.executeReplaceWith(newRootRaw, docs)
}

// executeReplaceWith runs the $replaceWith stage (the shorthand form of
// $replaceRoot, where spec itself is the new-root expression).
//
// For every input document the expression is evaluated against the document's
// data. If the result is a map it becomes the new document body; any other
// result is wrapped as {"value": result}. ID and timestamps are carried over.
func (e *AggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
	results := make([]types.Document, 0, len(docs))
	for _, doc := range docs {
		newRoot := e.evaluateExpression(doc.Data, spec)

		data, isObject := newRoot.(map[string]interface{})
		if !isObject {
			// Not an object: wrap the value so the output is still a document.
			data = map[string]interface{}{"value": newRoot}
		}

		results = append(results, types.Document{
			ID:        doc.ID,
			Data:      data,
			CreatedAt: doc.CreatedAt,
			UpdatedAt: doc.UpdatedAt,
		})
	}
	return results, nil
}

// executeGraphLookup runs the $graphLookup stage (recursive lookup).
//
// Recognised spec keys: "from", "startWith", "connectFromField",
// "connectToField", "as", "maxDepth" and "restrictSearchWithMatch".
// docs is returned unchanged when spec is not a map or when any of "as",
// "connectFromField" or "connectToField" is missing/empty.
//
// maxDepth follows the MongoDB definition: it is the maximum recursion depth,
// so 0 returns only the documents that directly match startWith, 1 adds one
// further hop, and so on. A missing, non-numeric or negative maxDepth means
// "no limit".
//
// Each output document is a shallow copy of the input with the matched
// documents stored under the "as" field (always a slice, never nil).
func (e *AggregationEngine) executeGraphLookup(spec interface{}, docs []types.Document) ([]types.Document, error) {
	specMap, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	from, _ := specMap["from"].(string)
	startWith := specMap["startWith"]
	connectFromField, _ := specMap["connectFromField"].(string)
	connectToField, _ := specMap["connectToField"].(string)
	as, _ := specMap["as"].(string)
	restrictSearchWithMatch := specMap["restrictSearchWithMatch"]

	if as == "" || connectFromField == "" || connectToField == "" {
		return docs, nil
	}

	// graphLookupRecursive counts levels to visit (negative = unlimited),
	// whereas maxDepth counts hops after the first level, hence the +1.
	levels := -1
	if raw, present := specMap["maxDepth"]; present && raw != nil {
		if depth, isNum := toNumber(raw); isNum && depth >= 0 {
			levels = int(depth) + 1
		}
	}

	results := make([]types.Document, 0, len(docs))
	for _, doc := range docs {
		// Resolve the value the traversal starts from.
		startValue := e.evaluateExpression(doc.Data, startWith)

		// A fresh visited set per input document: traversals are independent.
		connectedDocs := e.graphLookupRecursive(
			from,
			startValue,
			connectFromField,
			connectToField,
			levels,
			restrictSearchWithMatch,
			make(map[string]bool),
		)
		if connectedDocs == nil {
			// Keep the field an (empty) array rather than a nil slice.
			connectedDocs = []map[string]interface{}{}
		}

		newDoc := make(map[string]interface{}, len(doc.Data)+1)
		for k, v := range doc.Data {
			newDoc[k] = v
		}
		newDoc[as] = connectedDocs

		results = append(results, types.Document{
			ID:        doc.ID,
			Data:      newDoc,
			CreatedAt: doc.CreatedAt,
			UpdatedAt: doc.UpdatedAt,
		})
	}

	return results, nil
}

// graphLookupRecursive collects the documents of collection that are
// reachable from startValue by repeatedly matching connectToField against the
// connectFromField values of already matched documents.
//
// maxDepth is the number of levels to visit: 0 returns nothing, 1 returns only
// the direct matches, and a negative value means no limit. visited holds the
// IDs of documents that must be skipped and is updated with every document
// added to the result; a nil map is replaced by a fresh one.
// restrictSearchWithMatch, when it is a map, is applied through MatchFilter to
// every candidate; any other value is ignored.
//
// The traversal is breadth-first (the name is kept for compatibility). A
// depth-first walk with a shared visited set and a depth limit can reach a
// document first via a long path, mark it visited at the depth limit, and
// then never expand it when it is met again via a shorter path.
//
// Values are compared with valuesEqual. When a start/connectFromField value is
// an array, both the array itself and each of its elements are followed.
// Matched documents are returned as shallow copies of their data; within one
// level they appear in ascending document-ID order.
func (e *AggregationEngine) graphLookupRecursive(
	collection string,
	startValue interface{},
	connectFromField string,
	connectToField string,
	maxDepth int,
	restrictSearchWithMatch interface{},
	visited map[string]bool,
) []map[string]interface{} {
	var results []map[string]interface{}

	if maxDepth == 0 {
		return results
	}

	// Resolve the target collection.
	targetCollection := e.store.collections[collection]
	if targetCollection == nil {
		return results
	}

	if visited == nil {
		visited = make(map[string]bool)
	}

	matchSpec, hasMatchSpec := restrictSearchWithMatch.(map[string]interface{})

	// Scan in a fixed order so results do not depend on map iteration order.
	ids := make([]string, 0, len(targetCollection.documents))
	for id := range targetCollection.documents {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	// expand appends v, plus its elements when v is an array, to dst.
	expand := func(dst []interface{}, v interface{}) []interface{} {
		dst = append(dst, v)
		if items, isArray := v.([]interface{}); isArray {
			dst = append(dst, items...)
		}
		return dst
	}

	frontier := expand(nil, startValue)
	inFrontier := func(v interface{}) bool {
		for _, want := range frontier {
			if valuesEqual(want, v) {
				return true
			}
		}
		return false
	}

	for level := 0; len(frontier) > 0 && (maxDepth < 0 || level < maxDepth); level++ {
		var next []interface{}

		for _, id := range ids {
			// Skip documents already collected (also breaks reference cycles).
			if visited[id] {
				continue
			}

			doc := targetCollection.documents[id]
			if !inFrontier(getNestedValue(doc.Data, connectToField)) {
				continue
			}

			// Apply the optional restrictSearchWithMatch filter.
			if hasMatchSpec && !MatchFilter(doc.Data, matchSpec) {
				continue
			}

			visited[id] = true

			docCopy := make(map[string]interface{}, len(doc.Data))
			for k, v := range doc.Data {
				docCopy[k] = v
			}
			results = append(results, docCopy)

			// Values to follow on the next level.
			next = expand(next, getNestedValue(doc.Data, connectFromField))
		}

		frontier = next
	}

	return results
}

// executeSetWindowFields runs the $setWindowFields stage (window functions).
//
// Recognised spec keys: "output" (map of field name -> window spec),
// "partitionBy" (expression) and "sortBy" (map). docs is returned unchanged
// when spec is not a map or "output" is missing.
//
// Documents are grouped by the string form of the partitionBy result (a single
// partition when partitionBy is absent), each partition is sorted with
// sortDocsBySpec, and every output field is computed by calculateWindowValue.
// Partitions are emitted in the order in which they first appear in docs, so
// the output order is deterministic.
func (e *AggregationEngine) executeSetWindowFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
	specMap, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	outputsRaw, _ := specMap["output"].(map[string]interface{})
	partitionByRaw := specMap["partitionBy"]
	sortByRaw, _ := specMap["sortBy"].(map[string]interface{})

	if outputsRaw == nil {
		return docs, nil
	}

	// Group into partitions, remembering first-seen key order.
	partitions := make(map[string][]types.Document)
	var partitionOrder []string
	for _, doc := range docs {
		key := "all"
		if partitionByRaw != nil {
			key = fmt.Sprintf("%v", e.evaluateExpression(doc.Data, partitionByRaw))
		}
		if _, seen := partitions[key]; !seen {
			partitionOrder = append(partitionOrder, key)
		}
		partitions[key] = append(partitions[key], doc)
	}

	// Sort each partition.
	if len(sortByRaw) > 0 {
		for _, key := range partitionOrder {
			sortDocsBySpec(partitions[key], sortByRaw)
		}
	}

	// Apply the window functions.
	results := make([]types.Document, 0, len(docs))
	for _, key := range partitionOrder {
		partition := partitions[key]
		for i, doc := range partition {
			newDoc := make(map[string]interface{}, len(doc.Data)+len(outputsRaw))
			for k, v := range doc.Data {
				newDoc[k] = v
			}

			for fieldName, windowSpecRaw := range outputsRaw {
				windowSpec, isSpec := windowSpecRaw.(map[string]interface{})
				if !isSpec {
					continue
				}
				newDoc[fieldName] = e.calculateWindowValue(windowSpec, partition, i, doc)
			}

			results = append(results, types.Document{
				ID:        doc.ID,
				Data:      newDoc,
				CreatedAt: doc.CreatedAt,
				UpdatedAt: doc.UpdatedAt,
			})
		}
	}

	return results, nil
}

// calculateWindowValue computes one window-function output for the document
// at currentIndex of partition.
//
// windowSpec maps an operator to its operand. The "window" key (the window
// bounds) is never treated as an operator, and keys are inspected in sorted
// order, so the selected operator does not depend on map iteration order.
// Supported operators:
//
//   - $documentNumber: 1-based position in the partition.
//   - $rank: currently identical to $documentNumber (ties are not detected).
//   - $first / $last: operand evaluated on the first / last document.
//   - $shift: operand is either a number n (returns the whole data map of the
//     document n positions away, or nil when out of range) or a map with
//     "by", and optionally "output" and "default" expressions.
//   - $fillDefault: operand evaluated on the current document, 0 when nil.
//   - $sum, $avg, $min, $max: delegated to aggregateWindow.
//
// When no supported operator is present the whole windowSpec is evaluated as
// an ordinary expression against the current document. An empty windowSpec
// yields nil.
func (e *AggregationEngine) calculateWindowValue(
	windowSpec map[string]interface{},
	partition []types.Document,
	currentIndex int,
	currentDoc types.Document,
) interface{} {
	if len(windowSpec) == 0 {
		return nil
	}

	ops := make([]string, 0, len(windowSpec))
	for op := range windowSpec {
		if op != "window" {
			ops = append(ops, op)
		}
	}
	sort.Strings(ops)

	for _, op := range ops {
		operand := windowSpec[op]

		switch op {
		case "$documentNumber", "$rank":
			return float64(currentIndex + 1)

		case "$first":
			if len(partition) == 0 {
				return nil
			}
			return e.evaluateExpression(partition[0].Data, operand)

		case "$last":
			if len(partition) == 0 {
				return nil
			}
			return e.evaluateExpression(partition[len(partition)-1].Data, operand)

		case "$shift":
			var by interface{} = operand
			var outputExpr, defaultExpr interface{}
			if shiftSpec, isSpec := operand.(map[string]interface{}); isSpec {
				by = float64(0)
				if rawBy, hasBy := shiftSpec["by"]; hasBy {
					by = rawBy
				}
				outputExpr = shiftSpec["output"]
				defaultExpr = shiftSpec["default"]
			}

			targetIndex := currentIndex + int(toFloat64(by))
			if targetIndex < 0 || targetIndex >= len(partition) {
				if defaultExpr != nil {
					return e.evaluateExpression(currentDoc.Data, defaultExpr)
				}
				return nil
			}
			if outputExpr != nil {
				return e.evaluateExpression(partition[targetIndex].Data, outputExpr)
			}
			return partition[targetIndex].Data

		case "$fillDefault":
			val := e.evaluateExpression(currentDoc.Data, operand)
			if val == nil {
				return 0 // default value
			}
			return val

		case "$sum", "$avg", "$min", "$max":
			return e.aggregateWindow(op, operand, partition, currentIndex)
		}
	}

	// No window operator found: treat the spec as an ordinary expression.
	return e.evaluateExpression(currentDoc.Data, windowSpec)
}

// aggregateWindow computes $sum, $avg, $min or $max of operand over the
// documents of partition that fall inside the window around currentIndex.
//
// operand is evaluated on each document; results that toNumber cannot convert
// are skipped. nil is returned when no numeric value was found or op is not
// one of the four supported operators.
func (e *AggregationEngine) aggregateWindow(
	op string,
	operand interface{},
	partition []types.Document,
	currentIndex int,
) interface{} {
	// The window range does not depend on the document, so resolve it once.
	windowSpec := getWindowRange(op, operand)

	var values []float64
	for i, doc := range partition {
		if !inWindow(i, currentIndex, windowSpec) {
			continue
		}
		if num, ok := toNumber(e.evaluateExpression(doc.Data, operand)); ok {
			values = append(values, num)
		}
	}

	if len(values) == 0 {
		return nil
	}

	switch op {
	case "$sum", "$avg":
		total := 0.0
		for _, v := range values {
			total += v
		}
		if op == "$avg" {
			return total / float64(len(values))
		}
		return total

	case "$min":
		lowest := values[0]
		for _, v := range values[1:] {
			if v < lowest {
				lowest = v
			}
		}
		return lowest

	case "$max":
		highest := values[0]
		for _, v := range values[1:] {
			if v > highest {
				highest = v
			}
		}
		return highest

	default:
		return nil
	}
}

// getWindowRange returns the window range for an aggregate window operator.
//
// Simplified implementation: both arguments are ignored and the whole
// partition is always selected ({"window": "all"}).
func getWindowRange(op string, operand interface{}) map[string]interface{} {
	return map[string]interface{}{"window": "all"}
}

// inWindow reports whether the document at index belongs to the window of the
// document at current.
//
// Simplified implementation: every index is considered inside the window.
func inWindow(index, current int, windowSpec map[string]interface{}) bool {
	return true
}

// executeTextSearch runs a $text search over docs.
//
// search is split on whitespace into terms. Each term scores one point for
// every string value (at any nesting level) that contains it. Documents with
// a positive score are returned as shallow copies with the score stored under
// "_textScore", ordered by descending score; documents with equal scores keep
// their input order.
//
// When caseSensitive is false both the terms and the document text are
// lower-cased before matching; when it is true the terms are matched exactly
// as written. language is currently unused.
func (e *AggregationEngine) executeTextSearch(docs []types.Document, search string, language string, caseSensitive bool) ([]types.Document, error) {
	var results []types.Document

	// Only fold case for case-insensitive searches; lower-casing the query
	// unconditionally would make mixed-case terms unmatchable.
	query := search
	if !caseSensitive {
		query = strings.ToLower(query)
	}
	searchTerms := strings.Fields(query)

	for _, doc := range docs {
		score := e.calculateTextScore(doc.Data, searchTerms, caseSensitive)
		if score <= 0 {
			continue
		}

		// Attach the text score to a copy of the document.
		newDoc := make(map[string]interface{}, len(doc.Data)+1)
		for k, v := range doc.Data {
			newDoc[k] = v
		}
		newDoc["_textScore"] = score

		results = append(results, types.Document{
			ID:        doc.ID,
			Data:      newDoc,
			CreatedAt: doc.CreatedAt,
			UpdatedAt: doc.UpdatedAt,
		})
	}

	// Highest score first.
	sort.SliceStable(results, func(i, j int) bool {
		scoreI := results[i].Data["_textScore"].(float64)
		scoreJ := results[j].Data["_textScore"].(float64)
		return scoreI > scoreJ
	})

	return results, nil
}

// calculateTextScore returns the text-match score of doc for searchTerms by
// recursively scanning all of its string values with searchInValue.
func (e *AggregationEngine) calculateTextScore(doc map[string]interface{}, searchTerms []string, caseSensitive bool) float64 {
	score := 0.0
	e.searchInValue(doc, searchTerms, caseSensitive, &score)
	return score
}

// searchInValue adds 1.0 to *score for every (string value, term) pair where
// the string contains the term.
//
// Strings are tested directly; slices and maps are searched recursively; all
// other value types are ignored. When caseSensitive is false both the value
// and each term are lower-cased before comparison.
func (e *AggregationEngine) searchInValue(value interface{}, searchTerms []string, caseSensitive bool, score *float64) {
	switch v := value.(type) {
	case string:
		text := v
		if !caseSensitive {
			text = strings.ToLower(text)
		}
		for _, term := range searchTerms {
			if !caseSensitive {
				term = strings.ToLower(term)
			}
			if strings.Contains(text, term) {
				*score += 1.0
			}
		}

	case []interface{}:
		for _, item := range v {
			e.searchInValue(item, searchTerms, caseSensitive, score)
		}

	case map[string]interface{}:
		for _, item := range v {
			e.searchInValue(item, searchTerms, caseSensitive, score)
		}
	}
}

// sortDocsBySpec sorts docs in place according to a sort specification.
//
// Each entry of sortByRaw is normally fieldPath -> direction, where a negative
// numeric direction means descending and anything else ascending. For
// backward compatibility an entry whose value is a string is treated as an
// ascending sort on the field path named by that value.
//
// Because a Go map has no key order, multiple sort fields are applied in
// ascending order of their key names. Values compare as: nil/missing first,
// then numbers (numerically), then everything else (by its "%v" string
// form). The sort is stable.
func sortDocsBySpec(docs []types.Document, sortByRaw map[string]interface{}) {
	if len(docs) < 2 || len(sortByRaw) == 0 {
		return
	}

	type sortField struct {
		path       string
		descending bool
	}
	type sortKey struct {
		rank int // 0 = nil, 1 = number, 2 = other
		num  float64
		text string
	}
	type sortEntry struct {
		doc  types.Document
		keys []sortKey
	}

	// Resolve the sort fields in a deterministic order.
	names := make([]string, 0, len(sortByRaw))
	for name := range sortByRaw {
		names = append(names, name)
	}
	sort.Strings(names)

	fields := make([]sortField, 0, len(names))
	for _, name := range names {
		raw := sortByRaw[name]
		if path, isPath := raw.(string); isPath {
			fields = append(fields, sortField{path: path})
			continue
		}
		direction, isNum := toNumber(raw)
		fields = append(fields, sortField{path: name, descending: isNum && direction < 0})
	}

	// Extract every document's sort keys once, up front.
	entries := make([]sortEntry, len(docs))
	for i, doc := range docs {
		keys := make([]sortKey, len(fields))
		for f, field := range fields {
			value := getNestedValue(doc.Data, field.path)
			if value == nil {
				continue // zero sortKey already means "nil"
			}
			if num, isNum := toNumber(value); isNum {
				keys[f] = sortKey{rank: 1, num: num}
			} else {
				keys[f] = sortKey{rank: 2, text: fmt.Sprintf("%v", value)}
			}
		}
		entries[i] = sortEntry{doc: doc, keys: keys}
	}

	// compare orders by rank first so that mixed-type columns still form a
	// consistent total order.
	compare := func(a, b sortKey) int {
		switch {
		case a.rank != b.rank:
			if a.rank < b.rank {
				return -1
			}
			return 1
		case a.rank == 1:
			switch {
			case a.num < b.num:
				return -1
			case a.num > b.num:
				return 1
			}
			return 0
		case a.rank == 2:
			return strings.Compare(a.text, b.text)
		}
		return 0
	}

	sort.SliceStable(entries, func(i, j int) bool {
		for f, field := range fields {
			c := compare(entries[i].keys[f], entries[j].keys[f])
			if c == 0 {
				continue
			}
			if field.descending {
				return c > 0
			}
			return c < 0
		}
		return false
	})

	for i := range entries {
		docs[i] = entries[i].doc
	}
}

// getFieldValueStrFromDoc resolves fieldRaw against doc: when fieldRaw is a
// string it is used as a (possibly nested) field path into the document's
// data; any other value is returned unchanged.
func getFieldValueStrFromDoc(doc types.Document, fieldRaw interface{}) interface{} {
	if fieldStr, ok := fieldRaw.(string); ok {
		return getNestedValue(doc.Data, fieldStr)
	}
	return fieldRaw
}

// valuesEqual reports whether a and b are equal.
//
// Two nils are equal and nil never equals a non-nil value. All other values
// are compared by their "%v" string form, so values of different types that
// print identically (for example the number 1 and the string "1") are
// considered equal.
func valuesEqual(a, b interface{}) bool {
	if a == nil && b == nil {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
}

// getRandomDocuments returns n documents picked at random from docs.
//
// When n is at least len(docs) the input slice itself is returned. When n is
// zero or negative an empty slice is returned. Otherwise a copy of docs is
// shuffled and its first n elements returned, so the caller's slice is never
// reordered.
func getRandomDocuments(docs []types.Document, n int) []types.Document {
	if n >= len(docs) {
		return docs
	}
	if n <= 0 {
		return []types.Document{}
	}

	shuffled := make([]types.Document, len(docs))
	copy(shuffled, docs)

	// Use a private generator instead of reseeding the process-wide source.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	rng.Shuffle(len(shuffled), func(i, j int) {
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	})

	return shuffled[:n:n]
}