commit
dd8d7ca9d7
3 changed files with 193 additions and 0 deletions
@ -0,0 +1,17 @@ |
|||||
|
__pycache__/ |
||||
|
*.pyc |
||||
|
*.pyo |
||||
|
*.pyd |
||||
|
.env |
||||
|
*.env |
||||
|
.env.* |
||||
|
*.sqlite3 |
||||
|
*.db |
||||
|
.DS_Store |
||||
|
.vscode/ |
||||
|
.ipynb_checkpoints/ |
||||
|
.cache/ |
||||
|
*.log |
||||
|
*.egg-info/ |
||||
|
dist/ |
||||
|
build/ |
@ -0,0 +1,122 @@ |
|||||
|
# This file may be used to create an environment using: |
||||
|
# $ conda create --name <env> --file <this file> |
||||
|
# platform: win-64 |
||||
|
accelerate=1.6.0=pypi_0 |
||||
|
aiohappyeyeballs=2.6.1=pypi_0 |
||||
|
aiohttp=3.11.18=pypi_0 |
||||
|
aiosignal=1.3.2=pypi_0 |
||||
|
annotated-types=0.7.0=pypi_0 |
||||
|
anyio=4.9.0=pypi_0 |
||||
|
asttokens=3.0.0=pyhd8ed1ab_1 |
||||
|
attrs=25.3.0=pypi_0 |
||||
|
bitsandbytes=0.45.5=pypi_0 |
||||
|
bzip2=1.0.8=h2bbff1b_6 |
||||
|
ca-certificates=2025.4.26=h4c7d964_0 |
||||
|
certifi=2025.4.26=pypi_0 |
||||
|
charset-normalizer=3.4.2=pypi_0 |
||||
|
click=8.2.1=pypi_0 |
||||
|
colorama=0.4.6=pyhd8ed1ab_1 |
||||
|
comm=0.2.2=pyhd8ed1ab_1 |
||||
|
contourpy=1.3.2=pypi_0 |
||||
|
cycler=0.12.1=pypi_0 |
||||
|
datasets=3.5.1=pypi_0 |
||||
|
debugpy=1.8.11=py311h5da7b33_0 |
||||
|
decorator=5.2.1=pyhd8ed1ab_0 |
||||
|
dill=0.3.7=pypi_0 |
||||
|
einops=0.7.0=pypi_0 |
||||
|
exceptiongroup=1.2.2=pyhd8ed1ab_1 |
||||
|
executing=2.2.0=pyhd8ed1ab_0 |
||||
|
fastapi=0.116.0=pypi_0 |
||||
|
filelock=3.18.0=pypi_0 |
||||
|
fonttools=4.57.0=pypi_0 |
||||
|
frozenlist=1.6.0=pypi_0 |
||||
|
fsspec=2023.10.0=pypi_0 |
||||
|
h11=0.16.0=pypi_0 |
||||
|
huggingface-hub=0.30.2=pypi_0 |
||||
|
idna=3.10=pypi_0 |
||||
|
importlib-metadata=8.6.1=pyha770c72_0 |
||||
|
ipykernel=6.29.5=pyh4bbf305_0 |
||||
|
ipython=9.2.0=pyhca29cf9_0 |
||||
|
ipython_pygments_lexers=1.1.1=pyhd8ed1ab_0 |
||||
|
ipywidgets=8.1.7=pypi_0 |
||||
|
jedi=0.19.2=pyhd8ed1ab_1 |
||||
|
jinja2=3.1.6=pypi_0 |
||||
|
jupyter_client=8.6.3=pyhd8ed1ab_1 |
||||
|
jupyter_core=5.7.2=py311h1ea47a8_0 |
||||
|
jupyterlab-widgets=3.0.15=pypi_0 |
||||
|
kiwisolver=1.4.8=pypi_0 |
||||
|
libffi=3.4.4=hd77b12b_1 |
||||
|
libsodium=1.0.18=h8d14728_1 |
||||
|
markupsafe=3.0.2=pypi_0 |
||||
|
matplotlib=3.10.1=pypi_0 |
||||
|
matplotlib-inline=0.1.7=pyhd8ed1ab_1 |
||||
|
mpmath=1.3.0=pypi_0 |
||||
|
multidict=6.4.3=pypi_0 |
||||
|
multiprocess=0.70.15=pypi_0 |
||||
|
nest-asyncio=1.6.0=pyhd8ed1ab_1 |
||||
|
networkx=3.4.2=pypi_0 |
||||
|
numpy=2.2.5=pypi_0 |
||||
|
openssl=3.0.16=h3f729d1_0 |
||||
|
packaging=25.0=pyh29332c3_1 |
||||
|
pandas=2.2.3=pypi_0 |
||||
|
parso=0.8.4=pyhd8ed1ab_1 |
||||
|
peft=0.15.2=pypi_0 |
||||
|
pickleshare=0.7.5=pyhd8ed1ab_1004 |
||||
|
pillow=11.2.1=pypi_0 |
||||
|
pip=25.1=pyhc872135_1 |
||||
|
platformdirs=4.3.7=pyh29332c3_0 |
||||
|
prompt-toolkit=3.0.51=pyha770c72_0 |
||||
|
propcache=0.3.1=pypi_0 |
||||
|
psutil=7.0.0=pypi_0 |
||||
|
pure_eval=0.2.3=pyhd8ed1ab_1 |
||||
|
pyarrow=20.0.0=pypi_0 |
||||
|
pyarrow-hotfix=0.7=pypi_0 |
||||
|
pydantic=2.11.7=pypi_0 |
||||
|
pydantic-core=2.33.2=pypi_0 |
||||
|
pygments=2.19.1=pyhd8ed1ab_0 |
||||
|
pyparsing=3.2.3=pypi_0 |
||||
|
python=3.11.11=h4607a30_0 |
||||
|
python-dateutil=2.9.0.post0=pyhff2d567_1 |
||||
|
python_abi=3.11=2_cp311 |
||||
|
pytz=2025.2=pypi_0 |
||||
|
pywin32=308=py311h5da7b33_0 |
||||
|
pyyaml=6.0.2=pypi_0 |
||||
|
pyzmq=24.0.1=py311h7b3f143_1 |
||||
|
regex=2024.11.6=pypi_0 |
||||
|
requests=2.32.3=pypi_0 |
||||
|
safetensors=0.5.3=pypi_0 |
||||
|
scipy=1.15.2=pypi_0 |
||||
|
seaborn=0.13.2=pypi_0 |
||||
|
setuptools=78.1.1=py311haa95532_0 |
||||
|
six=1.17.0=pyhd8ed1ab_0 |
||||
|
sniffio=1.3.1=pypi_0 |
||||
|
sqlite=3.45.3=h2bbff1b_0 |
||||
|
stack_data=0.6.3=pyhd8ed1ab_1 |
||||
|
starlette=0.46.2=pypi_0 |
||||
|
sympy=1.14.0=pypi_0 |
||||
|
tk=8.6.14=h0416ee5_0 |
||||
|
tokenizers=0.21.1=pypi_0 |
||||
|
torch=2.7.0+cu128=pypi_0 |
||||
|
torchaudio=2.7.0+cu128=pypi_0 |
||||
|
torchvision=0.22.0+cu128=pypi_0 |
||||
|
tornado=6.2=py311ha68e1ae_1 |
||||
|
tqdm=4.67.1=pypi_0 |
||||
|
traitlets=5.14.3=pyhd8ed1ab_1 |
||||
|
transformers=4.51.3=pypi_0 |
||||
|
typing-inspection=0.4.1=pypi_0 |
||||
|
typing_extensions=4.13.2=pyh29332c3_0 |
||||
|
tzdata=2025.2=pypi_0 |
||||
|
ucrt=10.0.22621.0=h57928b3_1 |
||||
|
urllib3=2.4.0=pypi_0 |
||||
|
uvicorn=0.35.0=pypi_0 |
||||
|
vc=14.42=haa95532_5 |
||||
|
vs2015_runtime=14.42.34433=hbfb602d_5 |
||||
|
wcwidth=0.2.13=pyhd8ed1ab_1 |
||||
|
wheel=0.45.1=py311haa95532_0 |
||||
|
widgetsnbextension=4.0.14=pypi_0 |
||||
|
xxhash=3.5.0=pypi_0 |
||||
|
xz=5.6.4=h4754444_1 |
||||
|
yarl=1.20.0=pypi_0 |
||||
|
zeromq=4.3.4=h0e60522_1 |
||||
|
zipp=3.21.0=pyhd8ed1ab_1 |
||||
|
zlib=1.2.13=h8cc25b3_1 |
@ -0,0 +1,54 @@ |
|||||
|
from fastapi import FastAPI, Request |
||||
|
from pydantic import BaseModel |
||||
|
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM |
||||
|
import torch |
||||
|
|
||||
|
## Start the server with this command: uvicorn server:app --reload |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
use_base_model = False |
||||
|
|
||||
|
|
||||
|
if use_base_model: |
||||
|
# Load tokenizer and model (you can set torch_dtype=torch.float16 if on GPU) |
||||
|
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") |
||||
|
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
else: |
||||
|
# Load custom model from Hugging Face Hub |
||||
|
model_name = "Cailean/macbeth" |
||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
||||
|
|
||||
|
class Query(BaseModel): |
||||
|
speaker: str |
||||
|
sentence: str |
||||
|
emotion: str |
||||
|
temp: float = 0.7 |
||||
|
|
||||
|
@app.post("/test") |
||||
|
def send_text(query: Query): |
||||
|
output = "hi!" |
||||
|
return {"response": output} |
||||
|
|
||||
|
@app.post("/generate") |
||||
|
def generate_text(query: Query): |
||||
|
|
||||
|
speaker = query.speaker |
||||
|
sentence = query.sentence |
||||
|
emotion = query.emotion |
||||
|
temp = query.temp |
||||
|
|
||||
|
prompt = f"""Given the following dialogue from {speaker}: "{sentence}". Generate the next line of dialogue with the approriate speaker that expresses this specific emotion ({emotion}):""" |
||||
|
|
||||
|
print(prompt) |
||||
|
|
||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
||||
|
if use_base_model: |
||||
|
outputs = model.generate(**inputs, temperature=temp, do_sample=True) |
||||
|
else: |
||||
|
outputs = model.generate(input_ids=inputs["input_ids"], max_length=64, num_beams=1, early_stopping=True, do_sample=True, temperature=0.8, eos_token_id=tokenizer.eos_token_id) |
||||
|
|
||||
|
response = tokenizer.decode(outputs[0], skip_special_tokens=True) |
||||
|
print(response) |
||||
|
return {"response": response} |
Loading…
Reference in new issue