A fork of Shap-E for gc.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
1.7 KiB

5 months ago
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
import random
import torch
class TextGenerator:
    """Generate short "a (noun) with a (noun)" sentences with Phi-3-mini.

    The Phi-3 text-generation pipeline is built lazily on the first call to
    ``generate_text`` and cached on the instance, so the model weights are
    loaded once per instance instead of on every call.
    """

    # Hugging Face Hub id of the chat model used by generate_text().
    PHI3_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
    # Hugging Face Hub id of the model loaded by load_models().
    DISTILGPT2_MODEL_ID = "distilbert/distilgpt2"

    def __init__(self):
        # Populated only by load_models() (distilgpt2); generate_text() does
        # not read these attributes.
        self.model = None
        self.tokenizer = None
        # Cached Phi-3 text-generation pipeline, built on first generate_text().
        self._pipe = None

    def load_models(self):
        """Load the distilgpt2 tokenizer and model into ``self.tokenizer`` / ``self.model``.

        Note: generate_text() uses its own Phi-3 pipeline and does not use
        the objects loaded here.
        """
        print('Loading Models...')
        self.tokenizer = AutoTokenizer.from_pretrained(self.DISTILGPT2_MODEL_ID)
        self.model = AutoModelForCausalLM.from_pretrained(self.DISTILGPT2_MODEL_ID)
        print('Models Loaded!')

    def _build_pipeline(self):
        """Load Phi-3-mini plus its tokenizer and wrap them in a text-generation pipeline."""
        # Fall back to CPU instead of failing outright when no CUDA device exists.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            self.PHI3_MODEL_ID,
            device_map=device,
            torch_dtype="auto",
            # NOTE(review): trust_remote_code executes Python code shipped in
            # the Hub repo. Newer transformers releases support Phi-3 natively;
            # consider dropping this flag once the installed version is confirmed.
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.PHI3_MODEL_ID)
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )

    def generate_text(self):
        """Return one generated sentence of the form "a (noun) with a (noun)".

        The first call loads the Phi-3 model (on CUDA when available, otherwise
        CPU); subsequent calls reuse the cached pipeline.

        Returns:
            The ``generated_text`` field of the first pipeline result. Because
            ``return_full_text`` is False, the prompt is not included.
        """
        if getattr(self, "_pipe", None) is None:
            self._pipe = self._build_pipeline()
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant, that generates two nouns and returns one sentence in the format of: a (noun) with a (noun).\n You can describe a random object typically found in a bin"},
            {"role": "user", "content": "Can you provide me with a sentence"},
        ]
        generation_args = {
            "max_new_tokens": 50,  # short cap: the expected answer is a single sentence
            "return_full_text": False,  # return only the completion, not the prompt
            "temperature": 0.7,  # < 1.0 sharpens the distribution (less random than 1.0)
            "do_sample": True,  # sample rather than decode greedily, so outputs vary
            "top_k": 100,  # sample only among the 100 most likely tokens
            "top_p": 1,  # 1 keeps the whole distribution, i.e. nucleus filtering is off
        }
        output = self._pipe(messages, **generation_args)
        return output[0]['generated_text']