commit dd8d7ca9d720c95a43331068242a8a3b5e759f4f Author: cailean Date: Wed Jul 16 13:10:09 2025 +0100 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..053a29d --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +__pycache__/ +*.pyc +*.pyo +*.pyd +.env +*.env +.env.* +*.sqlite3 +*.db +.DS_Store +.vscode/ +.ipynb_checkpoints/ +.cache/ +*.log +*.egg-info/ +dist/ +build/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..add6beb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,122 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: win-64 +accelerate=1.6.0=pypi_0 +aiohappyeyeballs=2.6.1=pypi_0 +aiohttp=3.11.18=pypi_0 +aiosignal=1.3.2=pypi_0 +annotated-types=0.7.0=pypi_0 +anyio=4.9.0=pypi_0 +asttokens=3.0.0=pyhd8ed1ab_1 +attrs=25.3.0=pypi_0 +bitsandbytes=0.45.5=pypi_0 +bzip2=1.0.8=h2bbff1b_6 +ca-certificates=2025.4.26=h4c7d964_0 +certifi=2025.4.26=pypi_0 +charset-normalizer=3.4.2=pypi_0 +click=8.2.1=pypi_0 +colorama=0.4.6=pyhd8ed1ab_1 +comm=0.2.2=pyhd8ed1ab_1 +contourpy=1.3.2=pypi_0 +cycler=0.12.1=pypi_0 +datasets=3.5.1=pypi_0 +debugpy=1.8.11=py311h5da7b33_0 +decorator=5.2.1=pyhd8ed1ab_0 +dill=0.3.7=pypi_0 +einops=0.7.0=pypi_0 +exceptiongroup=1.2.2=pyhd8ed1ab_1 +executing=2.2.0=pyhd8ed1ab_0 +fastapi=0.116.0=pypi_0 +filelock=3.18.0=pypi_0 +fonttools=4.57.0=pypi_0 +frozenlist=1.6.0=pypi_0 +fsspec=2023.10.0=pypi_0 +h11=0.16.0=pypi_0 +huggingface-hub=0.30.2=pypi_0 +idna=3.10=pypi_0 +importlib-metadata=8.6.1=pyha770c72_0 +ipykernel=6.29.5=pyh4bbf305_0 +ipython=9.2.0=pyhca29cf9_0 +ipython_pygments_lexers=1.1.1=pyhd8ed1ab_0 +ipywidgets=8.1.7=pypi_0 +jedi=0.19.2=pyhd8ed1ab_1 +jinja2=3.1.6=pypi_0 +jupyter_client=8.6.3=pyhd8ed1ab_1 +jupyter_core=5.7.2=py311h1ea47a8_0 +jupyterlab-widgets=3.0.15=pypi_0 +kiwisolver=1.4.8=pypi_0 +libffi=3.4.4=hd77b12b_1 +libsodium=1.0.18=h8d14728_1 +markupsafe=3.0.2=pypi_0 +matplotlib=3.10.1=pypi_0 +matplotlib-inline=0.1.7=pyhd8ed1ab_1 +mpmath=1.3.0=pypi_0 +multidict=6.4.3=pypi_0 +multiprocess=0.70.15=pypi_0 +nest-asyncio=1.6.0=pyhd8ed1ab_1 +networkx=3.4.2=pypi_0 +numpy=2.2.5=pypi_0 +openssl=3.0.16=h3f729d1_0 +packaging=25.0=pyh29332c3_1 +pandas=2.2.3=pypi_0 +parso=0.8.4=pyhd8ed1ab_1 +peft=0.15.2=pypi_0 +pickleshare=0.7.5=pyhd8ed1ab_1004 +pillow=11.2.1=pypi_0 +pip=25.1=pyhc872135_1 +platformdirs=4.3.7=pyh29332c3_0 +prompt-toolkit=3.0.51=pyha770c72_0 +propcache=0.3.1=pypi_0 +psutil=7.0.0=pypi_0 +pure_eval=0.2.3=pyhd8ed1ab_1 +pyarrow=20.0.0=pypi_0 +pyarrow-hotfix=0.7=pypi_0 +pydantic=2.11.7=pypi_0 +pydantic-core=2.33.2=pypi_0 +pygments=2.19.1=pyhd8ed1ab_0 +pyparsing=3.2.3=pypi_0 +python=3.11.11=h4607a30_0 +python-dateutil=2.9.0.post0=pyhff2d567_1 +python_abi=3.11=2_cp311 +pytz=2025.2=pypi_0 +pywin32=308=py311h5da7b33_0 +pyyaml=6.0.2=pypi_0 +pyzmq=24.0.1=py311h7b3f143_1 +regex=2024.11.6=pypi_0 +requests=2.32.3=pypi_0 +safetensors=0.5.3=pypi_0 +scipy=1.15.2=pypi_0 +seaborn=0.13.2=pypi_0 +setuptools=78.1.1=py311haa95532_0 +six=1.17.0=pyhd8ed1ab_0 +sniffio=1.3.1=pypi_0 +sqlite=3.45.3=h2bbff1b_0 +stack_data=0.6.3=pyhd8ed1ab_1 +starlette=0.46.2=pypi_0 +sympy=1.14.0=pypi_0 +tk=8.6.14=h0416ee5_0 +tokenizers=0.21.1=pypi_0 +torch=2.7.0+cu128=pypi_0 +torchaudio=2.7.0+cu128=pypi_0 +torchvision=0.22.0+cu128=pypi_0 +tornado=6.2=py311ha68e1ae_1 +tqdm=4.67.1=pypi_0 +traitlets=5.14.3=pyhd8ed1ab_1 +transformers=4.51.3=pypi_0 +typing-inspection=0.4.1=pypi_0 +typing_extensions=4.13.2=pyh29332c3_0 +tzdata=2025.2=pypi_0 +ucrt=10.0.22621.0=h57928b3_1 +urllib3=2.4.0=pypi_0 +uvicorn=0.35.0=pypi_0 +vc=14.42=haa95532_5 +vs2015_runtime=14.42.34433=hbfb602d_5 +wcwidth=0.2.13=pyhd8ed1ab_1 +wheel=0.45.1=py311haa95532_0 +widgetsnbextension=4.0.14=pypi_0 +xxhash=3.5.0=pypi_0 +xz=5.6.4=h4754444_1 +yarl=1.20.0=pypi_0 +zeromq=4.3.4=h0e60522_1 +zipp=3.21.0=pyhd8ed1ab_1 +zlib=1.2.13=h8cc25b3_1 diff --git a/server.py b/server.py new file mode 100644 index 0000000..d821978 --- /dev/null +++ b/server.py @@ -0,0 +1,54 @@ +from fastapi import FastAPI, Request +from pydantic import BaseModel +from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM +import torch + +## Start the server with this command: uvicorn server:app --reload + +app = FastAPI() + +use_base_model = False + + +if use_base_model: + # Load tokenizer and model (you can set torch_dtype=torch.float16 if on GPU) + tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") + model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to("cuda" if torch.cuda.is_available() else "cpu") +else: + # Load custom model from Hugging Face Hub + model_name = "Cailean/macbeth" + model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(model_name) + +class Query(BaseModel): + speaker: str + sentence: str + emotion: str + temp: float = 0.7 + +@app.post("/test") +def send_text(query: Query): + output = "hi!" + return {"response": output} + +@app.post("/generate") +def generate_text(query: Query): + + speaker = query.speaker + sentence = query.sentence + emotion = query.emotion + temp = query.temp + + prompt = f"""Given the following dialogue from {speaker}: "{sentence}". Generate the next line of dialogue with the approriate speaker that expresses this specific emotion ({emotion}):""" + + print(prompt) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + if use_base_model: + outputs = model.generate(**inputs, temperature=temp, do_sample=True) + else: + outputs = model.generate(input_ids=inputs["input_ids"], max_length=64, num_beams=1, early_stopping=True, do_sample=True, temperature=0.8, eos_token_id=tokenizer.eos_token_id) + + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(response) + return {"response": response} \ No newline at end of file