FastAPI server for LLM integration with Macbeth CV


from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import torch
from peft import PeftModel
## Start the server with this command: uvicorn server:app --reload
app = FastAPI()
use_base_model = False

if use_base_model:
    # Load tokenizer and model (you can set torch_dtype=torch.float16 if on GPU)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to("cuda" if torch.cuda.is_available() else "cpu")
else:
    # Load custom model from Hugging Face Hub (PEFT adapter applied on top of the base model)
    model_name = "Cailean/MacbethPEFT"
    base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
    model = PeftModel.from_pretrained(base_model, model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
class Query(BaseModel):
    speaker: str
    sentence: str
    emotion: str
    temp: float = 0.7
@app.post("/test")
def send_text(query: Query):
output = "hi!"
return {"response": output}
@app.post("/generate")
def generate_text(query: Query):
speaker = query.speaker
sentence = query.sentence
emotion = query.emotion
temp = query.temp
prompt = f"""Given the following dialogue from {speaker}: "{sentence}". Generate the next line of dialogue with the approriate speaker that expresses this specific emotion ({emotion}):"""
print(prompt)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
if use_base_model:
outputs = model.generate(**inputs, temperature=temp, do_sample=True)
else:
outputs = model.generate(input_ids=inputs["input_ids"], max_length=64, num_beams=1, early_stopping=True, do_sample=True, temperature=0.8, eos_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
return {"response": response}