Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import com.example.appfunctions.agent.data.db.dao.ChatDao
import com.example.appfunctions.agent.data.db.entities.MessageEntity
import com.example.appfunctions.agent.data.db.entities.ThreadEntity

@Database(entities = [ThreadEntity::class, MessageEntity::class], version = 1, exportSchema = false)
@Database(entities = [ThreadEntity::class, MessageEntity::class], version = 2, exportSchema = false)
abstract class AppDatabase : RoomDatabase() {
abstract fun chatDao(): ChatDao
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ data class MessageEntity(
* only non-null if Assistant returned PendingIntent tool response.
*/
val pendingIntentId: String? = null,
val targetPackageName: String? = null,
)

enum class MessageRole {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ abstract class DataModule {
AppDatabase::class.java,
"app_database",
)
.fallbackToDestructiveMigration()
.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ class AgentOrchestrator

try {
val provider = thread.llmModel.providerName
val modelName = thread.llmModel.modelName
val apiKey = getApiKey(provider)
if (apiKey == null) {
completeMessageWithError(
Expand All @@ -108,39 +107,21 @@ class AgentOrchestrator
}

val disconnectedApps = settingsRepository.disconnectedApps.first()
val tools =
getAppFunctionsUseCase().first().values.flatten().filter { metadata ->
metadata.isEnabled && metadata.packageName !in disconnectedApps
}
var previousInteractionId = thread.latestInteractionId
val currentInput = message.textContent
var currentToolOutputs = emptyList<ToolOutput>()
var continueLoop = true

val llmProvider = llmProviderFactory.getProvider(provider)

while (continueLoop) {
val llmInput = prepareLlmInput(currentToolOutputs, currentInput)

currentToolOutputs = emptyList()
val response =
llmProvider.generateResponse(
previousInteractionId = previousInteractionId,
input = llmInput,
tools = tools,
apiKey = apiKey,
modelName = modelName,
)
when (val handleResult = handleLlmResponse(response, message, tools)) {
is HandleResult.Continue -> {
currentToolOutputs = handleResult.toolOutputs
previousInteractionId = handleResult.interactionId
}
is HandleResult.Stop -> {
continueLoop = false
}
}
}
val allTools = getAppFunctionsUseCase().first().values.flatten()

val targetPackageName = message.targetPackageName
val queryText = message.textContent

val tools = filterTools(allTools, disconnectedApps, targetPackageName)

runInteractionLoop(
message = message,
thread = thread,
apiKey = apiKey,
tools = tools,
initialInput = queryText,
)

_status.value = AgentStatus.Idle
} catch (e: Exception) {
Log.e("AgentOrchestrator", "Error processing message", e)
Expand All @@ -153,6 +134,60 @@ class AgentOrchestrator
}
}

private fun filterTools(
allTools: List<AppFunctionMetadata>,
disconnectedApps: Set<String>,
targetPackageName: String?,
): List<AppFunctionMetadata> {
return allTools.filter { metadata ->
metadata.isEnabled &&
metadata.packageName !in disconnectedApps &&
(targetPackageName == null || metadata.packageName == targetPackageName)
}
}

private suspend fun runInteractionLoop(
message: MessageEntity,
thread: ThreadEntity,
apiKey: String,
tools: List<AppFunctionMetadata>,
initialInput: String,
) {
val provider = thread.llmModel.providerName
val modelName = thread.llmModel.modelName
val llmProvider = llmProviderFactory.getProvider(provider)

var previousInteractionId = thread.latestInteractionId
var currentToolOutputs = emptyList<ToolOutput>()
var continueLoop = true
var currentInput = initialInput

while (continueLoop) {
val llmInput = prepareLlmInput(currentToolOutputs, currentInput)

currentToolOutputs = emptyList()
val response =
llmProvider.generateResponse(
previousInteractionId = previousInteractionId,
input = llmInput,
tools = tools,
apiKey = apiKey,
modelName = modelName,
)

when (val handleResult = handleLlmResponse(response, message, tools)) {
is HandleResult.Continue -> {
currentToolOutputs = handleResult.toolOutputs
previousInteractionId = handleResult.interactionId
}

is HandleResult.Stop -> {
continueLoop = false
}
}
}
}

private suspend fun getApiKey(provider: LlmProviderName): String? {
return when (provider) {
LlmProviderName.GEMINI -> settingsRepository.geminiApiKey.first()
Expand Down Expand Up @@ -222,6 +257,7 @@ class AgentOrchestrator
}
HandleResult.Continue(toolResult.toolOutputs, response.interactionId)
}

is ExecuteToolCallsResult.PendingIntentAction -> {
savePendingIntentUseCase(
toolResult.pendingIntentId,
Expand All @@ -236,6 +272,7 @@ class AgentOrchestrator
)
HandleResult.Stop
}

is ExecuteToolCallsResult.Error -> {
HandleResult.Stop
}
Expand All @@ -252,6 +289,7 @@ class AgentOrchestrator
HandleResult.Stop
}
}

is LlmResponse.Error -> {
Log.e("AgentOrchestrator", "LLM Error: ${response.errorMessage}")
completeMessageWithError(message.messageId, message.threadId, response.errorMessage)
Expand Down Expand Up @@ -321,13 +359,15 @@ class AgentOrchestrator
),
)
}

is ExecuteAppFunctionResult.PendingIntentAction -> {
val pendingIntentId = java.util.UUID.randomUUID().toString()
return ExecuteToolCallsResult.PendingIntentAction(
pendingIntentId,
executionResult.pendingIntent,
)
}

is ExecuteAppFunctionResult.Error ->
throw IllegalStateException(
"Tool execution failed for ${toolCall.functionId}: ${executionResult.exception.message}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SendMessageUseCase
textContent: String,
processingStatus: MessageProcessingStatus,
pendingIntentId: String? = null,
targetPackageName: String? = null,
) {
val message =
MessageEntity(
Expand All @@ -52,6 +53,7 @@ class SendMessageUseCase
timestamp = System.currentTimeMillis(),
processingStatus = processingStatus,
pendingIntentId = pendingIntentId,
targetPackageName = targetPackageName,
)
chatRepository.sendMessage(message)
}
Expand Down
Loading