Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
"timm>1.0.0",
"huggingface_hub[cli]",
"transformers>=4.39.2",
"bitsandbytes",
"autoawq",
"triton >= 3.0.0",
"accelerate",
"einops",
# Meta neural nets
Expand Down
65 changes: 53 additions & 12 deletions src/stretch/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .prompts.ok_robot_prompt import OkRobotPromptBuilder
from .prompts.pickup_prompt import PickupPromptBuilder
from .prompts.simple_prompt import SimpleStretchPromptBuilder
from .qwen_client import Qwen25Client
from .qwen_client import Qwen25Client, get_qwen_variants

# This is a list of all the modules that are imported when you use the import * syntax.
# The __all__ variable is used to define what symbols get exported when from a module when you use the import * syntax.
Expand All @@ -44,11 +44,47 @@


# Add all the various Qwen25 variants
qwen_variants = []
for model_size in ["0.5B", "1.5B", "3B", "7B", "14B", "32B", "72B"]:
for fine_tuning in ["Instruct", "Coder", "Math"]:
qwen_variants.append(f"qwen25-{model_size}-{fine_tuning}")
llms.update({variant: Qwen25Client for variant in qwen_variants})
qwen_variants = get_qwen_variants()
llms.update({variant: Qwen25Client for variant in qwen_variants})


def process_incoming_qwen_types(qwen_type: str):
    """Parse a qwen client-type string into its model configuration parts.

    Expected forms (see get_qwen_variants):
        "qwen25"                                  -> all defaults
        "qwen25[-<family>]-<size>[-Instruct][-<quant>]"
    e.g. "qwen25-Coder-7B-Instruct-Int4", "qwen25-3B-AWQ".

    Args:
        qwen_type: The client-type string, dash-separated.

    Returns:
        Tuple of (model_size, typing_option, finetuning_option,
        quantization_option). The quantization option is lower-cased;
        missing parts fall back to the defaults encoded below.
    """
    terms = qwen_type.split("-")
    if len(terms) == 1:
        # Bare "qwen25": fall back to the default configuration.
        return "3B", None, "Instruct", "int4"

    # Optional model family (Math/Coder/VL/Deepseek) comes right after "qwen25";
    # None means the plain Qwen2.5 language model.
    if terms[1] in ["Math", "Coder", "VL", "Deepseek"]:
        typing_option = terms.pop(1)
    else:
        typing_option = None

    # The next item is the model size (e.g. "3B").
    model_size = terms[1]

    if len(terms) < 3:
        # No suffix at all: Instruct fine-tune, no quantization.
        finetuning_option, quantization_option = "Instruct", None
    elif len(terms) >= 4:
        # e.g. "...-Instruct-Int4": explicit fine-tuning followed by quantization.
        finetuning_option, quantization_option = terms[2], terms[3].lower()
    elif "awq" in terms[2].lower():
        # e.g. "...-AWQ": AWQ checkpoints are Instruct fine-tuned.
        finetuning_option, quantization_option = "Instruct", "awq"
    elif "Instruct" in terms[2]:
        # e.g. "...-Instruct": Instruct fine-tune, no quantization.
        finetuning_option, quantization_option = "Instruct", None
    else:
        # e.g. "...-Int4": quantized base model, no fine-tuning.
        finetuning_option, quantization_option = None, terms[2].lower()

    return model_size, typing_option, finetuning_option, quantization_option


prompts = {
"simple": SimpleStretchPromptBuilder,
Expand Down Expand Up @@ -102,11 +138,16 @@ def get_llm_client(
return OpenaiClient(prompt, **kwargs)
elif "qwen" in client_type:
# Parse model size and fine-tuning from client_type
terms = client_type.split("-")
if len(terms) == 3:
model_size, fine_tuning = terms[1], terms[2]
else:
model_size, fine_tuning = "3B", "Instruct"
return Qwen25Client(prompt, **kwargs)
model_size, typing_option, fine_tuning, quantization_option = process_incoming_qwen_types(
client_type
)
return Qwen25Client(
prompt,
model_size=model_size,
fine_tuning=fine_tuning,
model_type=typing_option,
quantization=quantization_option,
**kwargs,
)
else:
raise ValueError(f"Invalid client type: {client_type}")
108 changes: 95 additions & 13 deletions src/stretch/llms/qwen_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,127 @@
import timeit
from typing import Any, Dict, Optional, Union

import torch
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from stretch.llms.base import AbstractLLMClient, AbstractPromptBuilder

# Available Qwen-family checkpoints on the HuggingFace hub:
# Coder: 32B, 14B, 7B, 3B, 1.5B, 0.5B, (None, Instruct, Instruct-AWQ, Instruct-GGUF, Instruct-GPTQ-Int4, Instruct-GPTQ-Int8)
# Math: 72B, 7B, 1.5B (None, Instruct)
# VL: 72B, 7B, 3B (Instruct, Instruct-AWQ)
# Base/Instruct: "0.5B", "1.5B", "3B", "7B", "14B", "32B", "72B"
# Deepseek: 1.5B, 7B, 14B, 32B

# Model families; None means the plain Qwen2.5 language model.
qwen_typing_options = ["Math", "Coder", "VL", "Deepseek", None]
# Valid fine-tuning/quantization suffixes per family.
qwen_quantization_options = {
    "VL": [None, "AWQ", "Int4", "Int8", "Instruct", "Instruct-Int4", "Instruct-Int8"],
    None: [None, "AWQ", "Int4", "Int8", "Instruct", "Instruct-Int4", "Instruct-Int8"],
    "Coder": [None, "AWQ", "Int4", "Int8", "Instruct", "Instruct-Int4", "Instruct-Int8"],
    "Math": [None, "Int4", "Int8", "Instruct", "Instruct-Int4", "Instruct-Int8"],
    "Deepseek": [None, "Int4", "Int8"],
}
# Valid model sizes per family.
qwen_sizes = {
    "VL": ["3B", "7B", "72B"],
    None: ["0.5B", "1.5B", "3B", "7B", "14B", "32B", "72B"],
    "Coder": ["0.5B", "1.5B", "3B", "7B", "14B", "32B"],
    "Math": ["1.5B", "7B", "72B"],
    # Fixed: DeepSeek-R1-Distill-Qwen ships a 32B model, not 72B
    # (matches the availability list at the top of this file).
    "Deepseek": ["1.5B", "7B", "14B", "32B"],
}


def get_qwen_variants():
    """Enumerate every supported qwen client-type string.

    Returns:
        List of strings such as "qwen25-3B-Instruct" or
        "qwen25-Coder-7B-Instruct-Int4", built from the family/size/
        quantization tables above.
    """
    qwen_variants = []
    for qwen_typing_option in qwen_typing_options:
        for qwen_quantization_option in qwen_quantization_options[qwen_typing_option]:
            for qwen_size in qwen_sizes[qwen_typing_option]:
                qwen_type = "qwen25"
                if qwen_typing_option is not None:
                    qwen_type += "-" + qwen_typing_option
                qwen_type += "-" + qwen_size
                if qwen_quantization_option is not None:
                    qwen_type += "-" + qwen_quantization_option
                qwen_variants.append(qwen_type)
    return qwen_variants


class Qwen25Client(AbstractLLMClient):
def __init__(
self,
prompt: Union[str, AbstractPromptBuilder],
prompt_kwargs: Optional[Dict[str, Any]] = None,
model_size: str = "3B",
fine_tuning: str = "Instruct",
fine_tuning: Optional[str] = "Instruct",
model_type: Optional[str] = None,
max_tokens: int = 4096,
device: str = "cuda",
quantization: Optional[str] = "int4",
):
super().__init__(prompt, prompt_kwargs)
assert device in ["cuda", "mps"], f"Invalid device: {device}"
assert model_size in [
"0.5B",
"1.5B",
"3B",
"7B",
"14B",
"32B",
"72B",
], f"Invalid model size: {model_size}"
assert fine_tuning in ["Instruct", "Coder", "Math"], f"Invalid fine-tuning: {fine_tuning}"
assert model_type in qwen_typing_options, f"Invalid model type: {model_type}"
assert model_size in qwen_sizes[model_type], f"Invalid model size: {model_size}"
assert fine_tuning in [None, "Instruct"], f"Invalid fine-tuning: {fine_tuning}"

self.max_tokens = max_tokens
model_name = f"Qwen/Qwen2.5-{model_size}-{fine_tuning}"

if model_type == "Deepseek":
model_name = f"deepseek-ai/DeepSeek-R1-Distill-Qwen-{model_size}"
elif model_type is None:
if fine_tuning is None:
model_name = f"Qwen/Qwen2.5-{model_size}"
else:
model_name = f"Qwen/Qwen2.5-{model_size}-{fine_tuning}"
else:
if fine_tuning is None:
model_name = f"Qwen/Qwen2.5-{model_type}-{model_size}"
else:
model_name = f"Qwen/Qwen2.5-{model_type}-{model_size}-{fine_tuning}"

print(f"Loading model: {model_name}")
model_kwargs = {"torch_dtype": "auto"}

quantization_config = None
if quantization is not None:
quantization = quantization.lower()
# Note: there were supposed to be other options but this is the only one that worked this way
if quantization == "awq":
model_kwargs["torch_dtype"] = torch.float16
model_name += "-AWQ"
elif quantization in ["int8", "int4"]:
try:
import bitsandbytes # noqa: F401
from transformers import BitsAndBytesConfig
except ImportError:
raise ImportError(
"bitsandbytes required for int4/int8 quantization: pip install bitsandbytes"
)

quantization_config = BitsAndBytesConfig(
load_in_4bit=(quantization == "int4"),
load_in_8bit=(quantization == "int8"),
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_kwargs["quantization_config"] = quantization_config
else:
raise ValueError(f"Unknown quantization method: {quantization}")

if quantization_config is not None:
model_kwargs["quantization_config"] = quantization_config

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
)
self.pipe = pipeline(
"text-generation", model=self.model, tokenizer=self.tokenizer, device=device
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
device=device,
model_kwargs=model_kwargs,
)

def __call__(self, command: str, verbose: bool = False):
Expand Down
2 changes: 1 addition & 1 deletion src/stretch/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module

__version__ = "0.3.0"
__version__ = "0.3.1"

__stretchpy_protocol__ = "spp0"

Expand Down