Gemini Endpoints
Source: agiresearch/AIOS.
# Wrapper around Google's Gemini models, exposed through the BaseLLM interface.
import re
from .base_llm import BaseLLM
import time
from ...utils.utils import get_from_env
import json
from pyopenagi.utils.chat_template import Response
class GeminiLLM(BaseLLM):
    """LLM backend that routes agent requests to a Google Gemini model."""

    def __init__(self, llm_name: str,
                 max_gpu_memory: dict = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256,
                 log_mode: str = "console"):
        """Forward every argument unchanged to ``BaseLLM.__init__``.

        NOTE(review): ``max_gpu_memory`` and ``eval_device`` are not used by
        anything visible in this class, and ``max_new_tokens`` is not
        forwarded to Gemini by ``process`` — confirm whether ``BaseLLM``
        consumes them (if it stores ``max_new_tokens``, consider passing it
        as ``max_output_tokens`` in the generation config).
        """
        super().__init__(llm_name,
                         max_gpu_memory,
                         eval_device,
                         max_new_tokens,
                         log_mode)

    def load_llm_and_tokenizer(self) -> None:
        """Configure the Gemini client and create the generative model.

        ``google.generativeai`` is imported lazily so the package is only
        required when this backend is actually used. The API key is obtained
        via ``get_from_env("GEMINI_API_KEY")``.

        Raises:
            ValueError: if ``self.model_name`` does not contain "gemini"
                (case-insensitive).
            ImportError: if the ``google-generativeai`` package is missing.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not re.search(r'gemini', self.model_name, re.IGNORECASE):
            raise ValueError(f"{self.model_name} is not a Gemini model name")
        try:
            import google.generativeai as genai
        except ImportError as err:
            raise ImportError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`."
            ) from err
        genai.configure(api_key=get_from_env("GEMINI_API_KEY"))
        self.model = genai.GenerativeModel(self.model_name)
        # No local tokenizer is set; only the GenerativeModel handle is kept.
        self.tokenizer = None

    def convert_messages(self, messages):
        """Map chat messages to Gemini's ``{"role", "parts"}`` layout.

        "user" and "system" messages are both mapped to the "user" role;
        every other role is mapped to "model". Returns None when *messages*
        is empty or None.

        NOTE(review): ``parts`` is a single dict here rather than a list of
        parts. That is harmless while ``process`` JSON-encodes the result into
        a text prompt, but would need to become a list if these dicts were
        ever passed to Gemini as structured contents — confirm before reuse.
        """
        if not messages:
            return None
        return [
            {
                "role": "user" if m["role"] in ["user", "system"] else "model",
                "parts": {"text": m["content"]},
            }
            for m in messages
        ]

    def process(self,
                agent_request,
                temperature: float = 0.0) -> Response:
        """Run one agent request against Gemini and wrap the reply.

        Marks the request as executing and records its start time. The
        query's messages are then converted to Gemini's role/parts layout.
        When tools are present, tool descriptions are first injected via
        ``tool_calling_input_format``. The messages are sent as a single
        JSON-encoded prompt, and the first candidate's text is post-processed.

        Args:
            agent_request: request exposing ``query.messages``,
                ``query.tools``, ``query.message_return_type``,
                ``agent_name``, ``set_status`` and ``set_start_time``.
            temperature: sampling temperature forwarded to Gemini through
                ``generation_config``.

        Returns:
            Response: tool calls when tools were supplied and the output
            parsed as tool calls; otherwise the model text (JSON-parsed when
            no tools were supplied and ``message_return_type == "json"``).

        Raises:
            IndexError: if Gemini returned no candidate or no content part.
        """
        agent_request.set_status("executing")
        agent_request.set_start_time(time.time())
        query = agent_request.query
        messages = query.messages
        tools = query.tools
        if tools:
            messages = self.tool_calling_input_format(messages, tools)
        # Convert roles to the ones Gemini understands.
        messages = self.convert_messages(messages=messages)

        self.logger.log(
            f"{agent_request.agent_name} is switched to executing.\n",
            level="executing",
        )
        # The whole conversation is serialized into ONE text prompt instead of
        # structured multi-turn contents. NOTE(review): presumably this avoids
        # Gemini's alternating user/model turn rule, since "system" and "user"
        # both map to "user" — confirm before switching to structured input.
        outputs = self.model.generate_content(
            json.dumps({"contents": messages}),
            generation_config={"temperature": temperature},
        )
        # Only the extraction is guarded, so IndexErrors raised while parsing
        # tool calls / JSON are not misreported as an empty generation.
        try:
            result = outputs.candidates[0].content.parts[0].text
        except IndexError as err:
            raise IndexError(
                f"{self.model_name} can not generate a valid result, please try again"
            ) from err
        # NOTE(review): status "done" and the end time are not set here;
        # presumably the caller does so after receiving the Response — confirm.
        return self._build_response(result, tools, query.message_return_type)

    def _build_response(self, result, tools, message_return_type) -> Response:
        """Wrap Gemini's raw text *result* in a Response (tool calls or text)."""
        if tools:
            tool_calls = self.parse_tool_calls(result)
            if tool_calls:
                return Response(response_message=None, tool_calls=tool_calls)
            # Tools were offered but none were called: return the raw text.
            return Response(response_message=result)
        if message_return_type == "json":
            result = self.parse_json_format(result)
        return Response(response_message=result)