from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import torch

## Start the server with this command: uvicorn server:app --reload

app = FastAPI()

# Toggle between the stock FLAN-T5 checkpoint and the fine-tuned Macbeth model below.
use_base_model = False

if use_base_model:
    # Load tokenizer and model (you can set torch_dtype=torch.float16 if on GPU)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to("cuda" if torch.cuda.is_available() else "cpu")
else:
    # Load custom model from Hugging Face Hub
    model_name = "Cailean/macbeth"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

class Query(BaseModel):
    speaker: str
    sentence: str
    emotion: str
    temp: float = 0.7

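# For reference, a request body the endpoints accept (the field values here are
# invented examples; `temp` falls back to 0.7 when omitted):
#   {"speaker": "MACBETH", "sentence": "Is this a dagger which I see before me?",
#    "emotion": "fear", "temp": 0.9}
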
@app.post("/test")
|
|
def send_text(query: Query):
|
|
output = "hi!"
|
|
return {"response": output}
|
|
|
|
@app.post("/generate")
|
|
def generate_text(query: Query):
|
|
|
|
speaker = query.speaker
|
|
sentence = query.sentence
|
|
emotion = query.emotion
|
|
temp = query.temp
|
|
|
|
prompt = f"""Given the following dialogue from {speaker}: "{sentence}". Generate the next line of dialogue with the approriate speaker that expresses this specific emotion ({emotion}):"""
|
|
|
|
print(prompt)
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
if use_base_model:
|
|
outputs = model.generate(**inputs, temperature=temp, do_sample=True)
|
|
else:
|
|
outputs = model.generate(input_ids=inputs["input_ids"], max_length=64, num_beams=1, early_stopping=True, do_sample=True, temperature=0.8, eos_token_id=tokenizer.eos_token_id)
|
|
|
|
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
print(response)
|
|
return {"response": response}
|
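
## A minimal client sketch for exercising the server once it is running; the
## host/port assume uvicorn's defaults and the field values are invented:
##
##   import requests
##   payload = {"speaker": "MACBETH", "sentence": "Is this a dagger which I see before me?",
##              "emotion": "fear", "temp": 0.7}
##   print(requests.post("http://127.0.0.1:8000/generate", json=payload).json()["response"])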