Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
models_ = utils.DeDuplication(models_)
groups_ := strings.Split(channel.Group, ",")
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {

// Expand models to include aliases for standard names
expandedModels := expandModelsWithAliases(models_, channel.Type)

abilities := make([]Ability, 0, len(expandedModels))
for _, model := range expandedModels {
for _, group := range groups_ {
ability := Ability{
Group: group,
Expand All @@ -70,6 +74,73 @@ func (channel *Channel) AddAbilities() error {
return DB.Create(&abilities).Error
}

// expandModelsWithAliases expands model list to include standard names for channel-specific models
func expandModelsWithAliases(models []string, channelType int) []string {
expandedModels := make([]string, 0)
modelSet := make(map[string]bool)

for _, model := range models {
model = strings.TrimSpace(model)
if model == "" {
continue
}

// Add original model
if !modelSet[model] {
expandedModels = append(expandedModels, model)
modelSet[model] = true
}

// Try to find standard name for this channel-specific model
standardName := getStandardModelNameForChannel(model, channelType)
if standardName != model && !modelSet[standardName] {
expandedModels = append(expandedModels, standardName)
modelSet[standardName] = true
}
}

return expandedModels
}

// getStandardModelNameForChannel returns the standard name for a channel-specific model
func getStandardModelNameForChannel(actualName string, channelType int) string {
// Import alias mapping (we'll create a lightweight version here to avoid circular imports)
aliasMap := getModelAliasesForChannelType(channelType)

for standard, actual := range aliasMap {
if actual == actualName {
return standard
}
}

return actualName
}

// getModelAliasesForChannelType returns model aliases for specific channel type
// This is a lightweight version to avoid importing the full alias module
func getModelAliasesForChannelType(channelType int) map[string]string {
switch channelType {
case 24: // OpenRouter
return map[string]string{
"gpt-4o": "openai/gpt-4o",
"gpt-4o-mini": "openai/gpt-4o-mini",
"gpt-4": "openai/gpt-4",
"gpt-4-turbo": "openai/gpt-4-turbo",
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
"o1": "openai/o1",
"o1-mini": "openai/o1-mini",
"o1-preview": "openai/o1-preview",
"claude-3-haiku": "anthropic/claude-3-haiku",
"claude-3-sonnet": "anthropic/claude-3-sonnet",
"claude-3-opus": "anthropic/claude-3-opus",
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
"claude-3.5-haiku": "anthropic/claude-3.5-haiku",
}
default:
return map[string]string{}
}
}

func (channel *Channel) DeleteAbilities() error {
return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
}
Expand Down
67 changes: 67 additions & 0 deletions relay/billing/ratio/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,23 +692,90 @@ func GetModelRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}

// Try channel-specific model ratio first
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := ModelRatio[model]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[model]; ok {
return ratio
}

// Try direct model name
if ratio, ok := ModelRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[name]; ok {
return ratio
}

// Try to find standard model name for alias lookup
standardName := getStandardModelNameForBilling(name, channelType)
if standardName != name {
// Try standard model name
if ratio, ok := ModelRatio[standardName]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[standardName]; ok {
return ratio
}

// Try standard model with channel type
standardModel := fmt.Sprintf("%s(%d)", standardName, channelType)
if ratio, ok := ModelRatio[standardModel]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[standardModel]; ok {
return ratio
}
}

logger.SysError("model ratio not found: " + name)
return 30
}

// getStandardModelNameForBilling returns the standard model name for billing lookup
func getStandardModelNameForBilling(actualName string, channelType int) string {
// Reverse alias mapping for billing
aliasMap := getBillingAliasMap(channelType)

for standard, actual := range aliasMap {
if actual == actualName {
return standard
}
}

return actualName
}

// getBillingAliasMap returns alias mapping for billing purposes
func getBillingAliasMap(channelType int) map[string]string {
switch channelType {
case 24: // OpenRouter
return map[string]string{
"gpt-4o": "openai/gpt-4o",
"gpt-4o-mini": "openai/gpt-4o-mini",
"gpt-4": "openai/gpt-4",
"gpt-4-turbo": "openai/gpt-4-turbo",
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
"claude-3-haiku": "anthropic/claude-3-haiku",
"claude-3-sonnet": "anthropic/claude-3-sonnet",
"claude-3-opus": "anthropic/claude-3-opus",
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
}
case 18: // Anthropic
return map[string]string{
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-opus": "claude-3-opus-20240229",
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
}
default:
return map[string]string{}
}
}

func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
Expand Down
58 changes: 58 additions & 0 deletions relay/controller/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,64 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo
return modelName, false
}

// resolveModelAlias resolves standard model names to channel-specific names
func resolveModelAlias(modelName string, channelType int) string {
// Lightweight alias resolution to avoid circular imports
aliasMap := getChannelModelAliases(channelType)

if actualName, exists := aliasMap[modelName]; exists {
return actualName
}

return modelName
}

// getChannelModelAliases returns model aliases for specific channel type
func getChannelModelAliases(channelType int) map[string]string {
switch channelType {
case 24: // OpenRouter
return map[string]string{
"gpt-4o": "openai/gpt-4o",
"gpt-4o-mini": "openai/gpt-4o-mini",
"gpt-4": "openai/gpt-4",
"gpt-4-turbo": "openai/gpt-4-turbo",
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
"gpt-3.5-turbo-0125": "openai/gpt-3.5-turbo-0125",
"o1": "openai/o1",
"o1-mini": "openai/o1-mini",
"o1-preview": "openai/o1-preview",
"claude-3-haiku": "anthropic/claude-3-haiku",
"claude-3-sonnet": "anthropic/claude-3-sonnet",
"claude-3-opus": "anthropic/claude-3-opus",
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
"claude-3.5-haiku": "anthropic/claude-3.5-haiku",
}
case 18: // Anthropic
return map[string]string{
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-opus": "claude-3-opus-20240229",
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
"claude-3.5-haiku": "claude-3-5-haiku-20241022",
}
case 28: // Gemini
return map[string]string{
"gemini-pro": "gemini-pro",
"gemini-pro-1.5": "gemini-1.5-pro-latest",
"gemini-flash-1.5": "gemini-1.5-flash-latest",
}
case 33: // Groq
return map[string]string{
"llama-3-8b-instruct": "llama3-8b-8192",
"llama-3-70b-instruct": "llama3-70b-8192",
"llama-3.1-8b-instruct": "llama-3.1-8b-instant",
"llama-3.1-70b-instruct": "llama-3.1-70b-versatile",
}
default:
return map[string]string{}
}
}

func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
if resp == nil {
if meta.ChannelType == channeltype.AwsClaude {
Expand Down
5 changes: 5 additions & 0 deletions relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {

// map model name
meta.OriginModelName = textRequest.Model

// First resolve model alias (standard name to channel-specific name)
textRequest.Model = resolveModelAlias(textRequest.Model, meta.ChannelType)

// Then apply channel-specific model mapping (if configured)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// set system prompt if not empty
Expand Down
Loading
Loading