Skip to content

Commit 67bf91c

Browse files
committed
Read model from Binja settings.
- The model is retreived from user preferences. - If it does not exist, we default to text-davinci-003. - Broke out the functionality that checks for a valid model to a new function for modularity. Relate #14.
1 parent 920ec41 commit 67bf91c

1 file changed

Lines changed: 25 additions & 10 deletions

File tree

src/agent.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44

55
import openai
6-
from openai.api_resources.engine import Engine
6+
from openai.api_resources.model import Model
77
from openai.error import APIError
88

99
from binaryninja.lowlevelil import LowLevelILFunction
@@ -45,16 +45,9 @@ def __init__(self,
4545
f'LowLevelILFunction, MediumLevelILFunction, or '
4646
f'HighLevelILFunction, got {type(function)}.')
4747

48-
# Get the list of available engines.
49-
engines: list[Engine] = openai.Engine.list().data
50-
# Ensure the user's selected engine is available.
51-
if engine not in [e.id for e in engines]:
52-
InvalidEngineException(f'Invalid engine: {engine}. Valid engines '
53-
f'are: {[e.id for e in engines]}')
54-
5548
# Set instance attributes.
5649
self.function = function
57-
self.engine = engine
50+
self.model = self.get_model()
5851

5952
def read_api_key(self, filename: Optional[Path]=None) -> str:
6053
'''Checks for the API key in three locations.
@@ -92,6 +85,28 @@ def read_api_key(self, filename: Optional[Path]=None) -> str:
9285
raise APIError('No API key found. Refer to the documentation to add the '
9386
'API key.')
9487

88+
def is_valid_model(self, model: str) -> bool:
89+
'''Checks if the model is valid by querying the OpenAI API.'''
90+
models: list[Model] = openai.Model.list().data
91+
return model in [m.id for m in models]
92+
93+
def get_model(self) -> str:
94+
'''Returns the model that the user has selected from Binary Ninja's
95+
preferences. The default value is set by the OpenAISettings class. If
96+
for some reason the user selected a model that doesn't exist, this
97+
function defaults to 'text-davinci-003'.
98+
'''
99+
settings: Settings = Settings()
100+
# Check that the key exists.
101+
if settings.contains('openai.model'):
102+
# Check that the key is not empty and get the user's selection.
103+
if model := settings.get_string('openai.model'):
104+
# Check that is a valid model by querying the OpenAI API.
105+
if self.is_valid_model(model):
106+
return model
107+
# Return a valid, default model.
108+
assert self.is_valid_model('text-davinci-003')
109+
return 'text-davinci-003'
95110

96111
def instruction_list(self, function: Union[LowLevelILFunction,
97112
MediumLevelILFunction,
@@ -122,7 +137,7 @@ def generate_query(self, function: Union[LowLevelILFunction,
122137
def send_query(self, query: str) -> str:
123138
'''Sends a query to the engine and returns the response.'''
124139
response: str = openai.Completion.create(
125-
model=self.engine,
140+
model=self.model,
126141
prompt=query,
127142
max_tokens=2_048
128143
)

0 commit comments

Comments
 (0)