
initial commit

main · cailean, 3 days ago · commit dd8d7ca9d7

.gitignore · 17 lines

@@ -0,0 +1,17 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.env
*.env
.env.*
*.sqlite3
*.db
.DS_Store
.vscode/
.ipynb_checkpoints/
.cache/
*.log
*.egg-info/
dist/
build/

requirements.txt · 122 lines

@@ -0,0 +1,122 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: win-64
accelerate=1.6.0=pypi_0
aiohappyeyeballs=2.6.1=pypi_0
aiohttp=3.11.18=pypi_0
aiosignal=1.3.2=pypi_0
annotated-types=0.7.0=pypi_0
anyio=4.9.0=pypi_0
asttokens=3.0.0=pyhd8ed1ab_1
attrs=25.3.0=pypi_0
bitsandbytes=0.45.5=pypi_0
bzip2=1.0.8=h2bbff1b_6
ca-certificates=2025.4.26=h4c7d964_0
certifi=2025.4.26=pypi_0
charset-normalizer=3.4.2=pypi_0
click=8.2.1=pypi_0
colorama=0.4.6=pyhd8ed1ab_1
comm=0.2.2=pyhd8ed1ab_1
contourpy=1.3.2=pypi_0
cycler=0.12.1=pypi_0
datasets=3.5.1=pypi_0
debugpy=1.8.11=py311h5da7b33_0
decorator=5.2.1=pyhd8ed1ab_0
dill=0.3.7=pypi_0
einops=0.7.0=pypi_0
exceptiongroup=1.2.2=pyhd8ed1ab_1
executing=2.2.0=pyhd8ed1ab_0
fastapi=0.116.0=pypi_0
filelock=3.18.0=pypi_0
fonttools=4.57.0=pypi_0
frozenlist=1.6.0=pypi_0
fsspec=2023.10.0=pypi_0
h11=0.16.0=pypi_0
huggingface-hub=0.30.2=pypi_0
idna=3.10=pypi_0
importlib-metadata=8.6.1=pyha770c72_0
ipykernel=6.29.5=pyh4bbf305_0
ipython=9.2.0=pyhca29cf9_0
ipython_pygments_lexers=1.1.1=pyhd8ed1ab_0
ipywidgets=8.1.7=pypi_0
jedi=0.19.2=pyhd8ed1ab_1
jinja2=3.1.6=pypi_0
jupyter_client=8.6.3=pyhd8ed1ab_1
jupyter_core=5.7.2=py311h1ea47a8_0
jupyterlab-widgets=3.0.15=pypi_0
kiwisolver=1.4.8=pypi_0
libffi=3.4.4=hd77b12b_1
libsodium=1.0.18=h8d14728_1
markupsafe=3.0.2=pypi_0
matplotlib=3.10.1=pypi_0
matplotlib-inline=0.1.7=pyhd8ed1ab_1
mpmath=1.3.0=pypi_0
multidict=6.4.3=pypi_0
multiprocess=0.70.15=pypi_0
nest-asyncio=1.6.0=pyhd8ed1ab_1
networkx=3.4.2=pypi_0
numpy=2.2.5=pypi_0
openssl=3.0.16=h3f729d1_0
packaging=25.0=pyh29332c3_1
pandas=2.2.3=pypi_0
parso=0.8.4=pyhd8ed1ab_1
peft=0.15.2=pypi_0
pickleshare=0.7.5=pyhd8ed1ab_1004
pillow=11.2.1=pypi_0
pip=25.1=pyhc872135_1
platformdirs=4.3.7=pyh29332c3_0
prompt-toolkit=3.0.51=pyha770c72_0
propcache=0.3.1=pypi_0
psutil=7.0.0=pypi_0
pure_eval=0.2.3=pyhd8ed1ab_1
pyarrow=20.0.0=pypi_0
pyarrow-hotfix=0.7=pypi_0
pydantic=2.11.7=pypi_0
pydantic-core=2.33.2=pypi_0
pygments=2.19.1=pyhd8ed1ab_0
pyparsing=3.2.3=pypi_0
python=3.11.11=h4607a30_0
python-dateutil=2.9.0.post0=pyhff2d567_1
python_abi=3.11=2_cp311
pytz=2025.2=pypi_0
pywin32=308=py311h5da7b33_0
pyyaml=6.0.2=pypi_0
pyzmq=24.0.1=py311h7b3f143_1
regex=2024.11.6=pypi_0
requests=2.32.3=pypi_0
safetensors=0.5.3=pypi_0
scipy=1.15.2=pypi_0
seaborn=0.13.2=pypi_0
setuptools=78.1.1=py311haa95532_0
six=1.17.0=pyhd8ed1ab_0
sniffio=1.3.1=pypi_0
sqlite=3.45.3=h2bbff1b_0
stack_data=0.6.3=pyhd8ed1ab_1
starlette=0.46.2=pypi_0
sympy=1.14.0=pypi_0
tk=8.6.14=h0416ee5_0
tokenizers=0.21.1=pypi_0
torch=2.7.0+cu128=pypi_0
torchaudio=2.7.0+cu128=pypi_0
torchvision=0.22.0+cu128=pypi_0
tornado=6.2=py311ha68e1ae_1
tqdm=4.67.1=pypi_0
traitlets=5.14.3=pyhd8ed1ab_1
transformers=4.51.3=pypi_0
typing-inspection=0.4.1=pypi_0
typing_extensions=4.13.2=pyh29332c3_0
tzdata=2025.2=pypi_0
ucrt=10.0.22621.0=h57928b3_1
urllib3=2.4.0=pypi_0
uvicorn=0.35.0=pypi_0
vc=14.42=haa95532_5
vs2015_runtime=14.42.34433=hbfb602d_5
wcwidth=0.2.13=pyhd8ed1ab_1
wheel=0.45.1=py311haa95532_0
widgetsnbextension=4.0.14=pypi_0
xxhash=3.5.0=pypi_0
xz=5.6.4=h4754444_1
yarl=1.20.0=pypi_0
zeromq=4.3.4=h0e60522_1
zipp=3.21.0=pyhd8ed1ab_1
zlib=1.2.13=h8cc25b3_1

server.py · 54 lines

@@ -0,0 +1,54 @@
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import torch

# Start the server with: uvicorn server:app --reload
app = FastAPI()

use_base_model = False

if use_base_model:
    # Load tokenizer and base model (you can set torch_dtype=torch.float16 if on GPU)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
else:
    # Load the custom fine-tuned model from the Hugging Face Hub
    model_name = "Cailean/macbeth"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)


class Query(BaseModel):
    speaker: str
    sentence: str
    emotion: str
    temp: float = 0.7


@app.post("/test")
def send_text(query: Query):
    # Simple connectivity check; always returns a fixed string.
    output = "hi!"
    return {"response": output}


@app.post("/generate")
def generate_text(query: Query):
    speaker = query.speaker
    sentence = query.sentence
    emotion = query.emotion
    temp = query.temp
    prompt = f"""Given the following dialogue from {speaker}: "{sentence}". Generate the next line of dialogue with the appropriate speaker that expresses this specific emotion ({emotion}):"""
    print(prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    if use_base_model:
        outputs = model.generate(**inputs, temperature=temp, do_sample=True)
    else:
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            max_length=64,
            num_beams=1,
            early_stopping=True,
            do_sample=True,
            temperature=temp,  # was hardcoded to 0.8; use the value from the request
            eos_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(response)
    return {"response": response}
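For reference, a minimal client sketch against the /generate endpoint (assuming the server is running locally on uvicorn's default port 8000; the speaker/sentence/emotion values are placeholders, and the payload fields mirror the Query model above):

import requests

# Hypothetical payload; field names match the Query model in server.py.
payload = {
    "speaker": "MACBETH",
    "sentence": "Is this a dagger which I see before me?",
    "emotion": "fear",
    "temp": 0.7,
}

resp = requests.post("http://127.0.0.1:8000/generate", json=payload)
print(resp.json()["response"])

The same payload shape works for the /test endpoint, which just returns a fixed string.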