|
3 | 3 | from pathlib import Path |
4 | 4 |
|
5 | 5 | import openai |
6 | | -from openai.api_resources.engine import Engine |
| 6 | +from openai.api_resources.model import Model |
7 | 7 | from openai.error import APIError |
8 | 8 |
|
9 | 9 | from binaryninja.lowlevelil import LowLevelILFunction |
@@ -45,16 +45,9 @@ def __init__(self, |
45 | 45 | f'LowLevelILFunction, MediumLevelILFunction, or ' |
46 | 46 | f'HighLevelILFunction, got {type(function)}.') |
47 | 47 |
|
48 | | - # Get the list of available engines. |
49 | | - engines: list[Engine] = openai.Engine.list().data |
50 | | - # Ensure the user's selected engine is available. |
51 | | - if engine not in [e.id for e in engines]: |
52 | | - InvalidEngineException(f'Invalid engine: {engine}. Valid engines ' |
53 | | - f'are: {[e.id for e in engines]}') |
54 | | - |
55 | 48 | # Set instance attributes. |
56 | 49 | self.function = function |
57 | | - self.engine = engine |
| 50 | + self.model = self.get_model() |
58 | 51 |
|
59 | 52 | def read_api_key(self, filename: Optional[Path]=None) -> str: |
60 | 53 | '''Checks for the API key in three locations. |
@@ -92,6 +85,28 @@ def read_api_key(self, filename: Optional[Path]=None) -> str: |
92 | 85 | raise APIError('No API key found. Refer to the documentation to add the ' |
93 | 86 | 'API key.') |
94 | 87 |
|
| 88 | + def is_valid_model(self, model: str) -> bool: |
| 89 | + '''Checks if the model is valid by querying the OpenAI API.''' |
| 90 | + models: list[Model] = openai.Model.list().data |
| 91 | + return model in [m.id for m in models] |
| 92 | + |
| 93 | + def get_model(self) -> str: |
| 94 | + '''Returns the model that the user has selected from Binary Ninja's |
| 95 | + preferences. The default value is set by the OpenAISettings class. If |
| 96 | + for some reason the user selected a model that doesn't exist, this |
| 97 | + function defaults to 'text-davinci-003'. |
| 98 | + ''' |
| 99 | + settings: Settings = Settings() |
| 100 | + # Check that the key exists. |
| 101 | + if settings.contains('openai.model'): |
| 102 | + # Check that the key is not empty and get the user's selection. |
| 103 | + if model := settings.get_string('openai.model'): |
| 104 | + # Check that is a valid model by querying the OpenAI API. |
| 105 | + if self.is_valid_model(model): |
| 106 | + return model |
| 107 | + # Return a valid, default model. |
| 108 | + assert self.is_valid_model('text-davinci-003') |
| 109 | + return 'text-davinci-003' |
95 | 110 |
|
96 | 111 | def instruction_list(self, function: Union[LowLevelILFunction, |
97 | 112 | MediumLevelILFunction, |
@@ -122,7 +137,7 @@ def generate_query(self, function: Union[LowLevelILFunction, |
122 | 137 | def send_query(self, query: str) -> str: |
123 | 138 | '''Sends a query to the engine and returns the response.''' |
124 | 139 | response: str = openai.Completion.create( |
125 | | - model=self.engine, |
| 140 | + model=self.model, |
126 | 141 | prompt=query, |
127 | 142 | max_tokens=2_048 |
128 | 143 | ) |
|
0 commit comments