commit
dd8d7ca9d7
3 changed files with 193 additions and 0 deletions
@ -0,0 +1,17 @@ |
|||
__pycache__/ |
|||
*.pyc |
|||
*.pyo |
|||
*.pyd |
|||
.env |
|||
*.env |
|||
.env.* |
|||
*.sqlite3 |
|||
*.db |
|||
.DS_Store |
|||
.vscode/ |
|||
.ipynb_checkpoints/ |
|||
.cache/ |
|||
*.log |
|||
*.egg-info/ |
|||
dist/ |
|||
build/ |
@ -0,0 +1,122 @@ |
|||
# This file may be used to create an environment using: |
|||
# $ conda create --name <env> --file <this file> |
|||
# platform: win-64 |
|||
accelerate=1.6.0=pypi_0 |
|||
aiohappyeyeballs=2.6.1=pypi_0 |
|||
aiohttp=3.11.18=pypi_0 |
|||
aiosignal=1.3.2=pypi_0 |
|||
annotated-types=0.7.0=pypi_0 |
|||
anyio=4.9.0=pypi_0 |
|||
asttokens=3.0.0=pyhd8ed1ab_1 |
|||
attrs=25.3.0=pypi_0 |
|||
bitsandbytes=0.45.5=pypi_0 |
|||
bzip2=1.0.8=h2bbff1b_6 |
|||
ca-certificates=2025.4.26=h4c7d964_0 |
|||
certifi=2025.4.26=pypi_0 |
|||
charset-normalizer=3.4.2=pypi_0 |
|||
click=8.2.1=pypi_0 |
|||
colorama=0.4.6=pyhd8ed1ab_1 |
|||
comm=0.2.2=pyhd8ed1ab_1 |
|||
contourpy=1.3.2=pypi_0 |
|||
cycler=0.12.1=pypi_0 |
|||
datasets=3.5.1=pypi_0 |
|||
debugpy=1.8.11=py311h5da7b33_0 |
|||
decorator=5.2.1=pyhd8ed1ab_0 |
|||
dill=0.3.7=pypi_0 |
|||
einops=0.7.0=pypi_0 |
|||
exceptiongroup=1.2.2=pyhd8ed1ab_1 |
|||
executing=2.2.0=pyhd8ed1ab_0 |
|||
fastapi=0.116.0=pypi_0 |
|||
filelock=3.18.0=pypi_0 |
|||
fonttools=4.57.0=pypi_0 |
|||
frozenlist=1.6.0=pypi_0 |
|||
fsspec=2023.10.0=pypi_0 |
|||
h11=0.16.0=pypi_0 |
|||
huggingface-hub=0.30.2=pypi_0 |
|||
idna=3.10=pypi_0 |
|||
importlib-metadata=8.6.1=pyha770c72_0 |
|||
ipykernel=6.29.5=pyh4bbf305_0 |
|||
ipython=9.2.0=pyhca29cf9_0 |
|||
ipython_pygments_lexers=1.1.1=pyhd8ed1ab_0 |
|||
ipywidgets=8.1.7=pypi_0 |
|||
jedi=0.19.2=pyhd8ed1ab_1 |
|||
jinja2=3.1.6=pypi_0 |
|||
jupyter_client=8.6.3=pyhd8ed1ab_1 |
|||
jupyter_core=5.7.2=py311h1ea47a8_0 |
|||
jupyterlab-widgets=3.0.15=pypi_0 |
|||
kiwisolver=1.4.8=pypi_0 |
|||
libffi=3.4.4=hd77b12b_1 |
|||
libsodium=1.0.18=h8d14728_1 |
|||
markupsafe=3.0.2=pypi_0 |
|||
matplotlib=3.10.1=pypi_0 |
|||
matplotlib-inline=0.1.7=pyhd8ed1ab_1 |
|||
mpmath=1.3.0=pypi_0 |
|||
multidict=6.4.3=pypi_0 |
|||
multiprocess=0.70.15=pypi_0 |
|||
nest-asyncio=1.6.0=pyhd8ed1ab_1 |
|||
networkx=3.4.2=pypi_0 |
|||
numpy=2.2.5=pypi_0 |
|||
openssl=3.0.16=h3f729d1_0 |
|||
packaging=25.0=pyh29332c3_1 |
|||
pandas=2.2.3=pypi_0 |
|||
parso=0.8.4=pyhd8ed1ab_1 |
|||
peft=0.15.2=pypi_0 |
|||
pickleshare=0.7.5=pyhd8ed1ab_1004 |
|||
pillow=11.2.1=pypi_0 |
|||
pip=25.1=pyhc872135_1 |
|||
platformdirs=4.3.7=pyh29332c3_0 |
|||
prompt-toolkit=3.0.51=pyha770c72_0 |
|||
propcache=0.3.1=pypi_0 |
|||
psutil=7.0.0=pypi_0 |
|||
pure_eval=0.2.3=pyhd8ed1ab_1 |
|||
pyarrow=20.0.0=pypi_0 |
|||
pyarrow-hotfix=0.7=pypi_0 |
|||
pydantic=2.11.7=pypi_0 |
|||
pydantic-core=2.33.2=pypi_0 |
|||
pygments=2.19.1=pyhd8ed1ab_0 |
|||
pyparsing=3.2.3=pypi_0 |
|||
python=3.11.11=h4607a30_0 |
|||
python-dateutil=2.9.0.post0=pyhff2d567_1 |
|||
python_abi=3.11=2_cp311 |
|||
pytz=2025.2=pypi_0 |
|||
pywin32=308=py311h5da7b33_0 |
|||
pyyaml=6.0.2=pypi_0 |
|||
pyzmq=24.0.1=py311h7b3f143_1 |
|||
regex=2024.11.6=pypi_0 |
|||
requests=2.32.3=pypi_0 |
|||
safetensors=0.5.3=pypi_0 |
|||
scipy=1.15.2=pypi_0 |
|||
seaborn=0.13.2=pypi_0 |
|||
setuptools=78.1.1=py311haa95532_0 |
|||
six=1.17.0=pyhd8ed1ab_0 |
|||
sniffio=1.3.1=pypi_0 |
|||
sqlite=3.45.3=h2bbff1b_0 |
|||
stack_data=0.6.3=pyhd8ed1ab_1 |
|||
starlette=0.46.2=pypi_0 |
|||
sympy=1.14.0=pypi_0 |
|||
tk=8.6.14=h0416ee5_0 |
|||
tokenizers=0.21.1=pypi_0 |
|||
torch=2.7.0+cu128=pypi_0 |
|||
torchaudio=2.7.0+cu128=pypi_0 |
|||
torchvision=0.22.0+cu128=pypi_0 |
|||
tornado=6.2=py311ha68e1ae_1 |
|||
tqdm=4.67.1=pypi_0 |
|||
traitlets=5.14.3=pyhd8ed1ab_1 |
|||
transformers=4.51.3=pypi_0 |
|||
typing-inspection=0.4.1=pypi_0 |
|||
typing_extensions=4.13.2=pyh29332c3_0 |
|||
tzdata=2025.2=pypi_0 |
|||
ucrt=10.0.22621.0=h57928b3_1 |
|||
urllib3=2.4.0=pypi_0 |
|||
uvicorn=0.35.0=pypi_0 |
|||
vc=14.42=haa95532_5 |
|||
vs2015_runtime=14.42.34433=hbfb602d_5 |
|||
wcwidth=0.2.13=pyhd8ed1ab_1 |
|||
wheel=0.45.1=py311haa95532_0 |
|||
widgetsnbextension=4.0.14=pypi_0 |
|||
xxhash=3.5.0=pypi_0 |
|||
xz=5.6.4=h4754444_1 |
|||
yarl=1.20.0=pypi_0 |
|||
zeromq=4.3.4=h0e60522_1 |
|||
zipp=3.21.0=pyhd8ed1ab_1 |
|||
zlib=1.2.13=h8cc25b3_1 |
@ -0,0 +1,54 @@ |
|||
from fastapi import FastAPI, Request |
|||
from pydantic import BaseModel |
|||
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM |
|||
import torch |
|||
|
|||
## Start the server with this command: uvicorn server:app --reload |
|||
|
|||
app = FastAPI() |
|||
|
|||
use_base_model = False |
|||
|
|||
|
|||
if use_base_model: |
|||
# Load tokenizer and model (you can set torch_dtype=torch.float16 if on GPU) |
|||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") |
|||
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to("cuda" if torch.cuda.is_available() else "cpu") |
|||
else: |
|||
# Load custom model from Hugging Face Hub |
|||
model_name = "Cailean/macbeth" |
|||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu") |
|||
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|||
|
|||
class Query(BaseModel): |
|||
speaker: str |
|||
sentence: str |
|||
emotion: str |
|||
temp: float = 0.7 |
|||
|
|||
@app.post("/test") |
|||
def send_text(query: Query): |
|||
output = "hi!" |
|||
return {"response": output} |
|||
|
|||
@app.post("/generate") |
|||
def generate_text(query: Query): |
|||
|
|||
speaker = query.speaker |
|||
sentence = query.sentence |
|||
emotion = query.emotion |
|||
temp = query.temp |
|||
|
|||
prompt = f"""Given the following dialogue from {speaker}: "{sentence}". Generate the next line of dialogue with the approriate speaker that expresses this specific emotion ({emotion}):""" |
|||
|
|||
print(prompt) |
|||
|
|||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
|||
if use_base_model: |
|||
outputs = model.generate(**inputs, temperature=temp, do_sample=True) |
|||
else: |
|||
outputs = model.generate(input_ids=inputs["input_ids"], max_length=64, num_beams=1, early_stopping=True, do_sample=True, temperature=0.8, eos_token_id=tokenizer.eos_token_id) |
|||
|
|||
response = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|||
print(response) |
|||
return {"response": response} |
Loading…
Reference in new issue