diff --git a/requirements.txt b/requirements.txt index 613aac9..78ff32f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,23 +4,16 @@ aiohttp==3.11.18 aiosignal==1.3.2 annotated-types==0.7.0 anyio==4.9.0 -asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work attrs==25.3.0 bitsandbytes==0.45.5 certifi==2025.4.26 charset-normalizer==3.4.2 click==8.2.1 -colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work -comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work contourpy==1.3.2 cycler==0.12.1 datasets==3.5.1 -debugpy @ file:///C:/b/abs_bf9oo2vhxp/croot/debugpy_1736269476451/work -decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work dill==0.3.7 einops==0.7.0 -exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work -executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work fastapi==0.116.0 filelock==3.18.0 fonttools==4.57.0 @@ -29,73 +22,47 @@ fsspec==2023.10.0 h11==0.16.0 huggingface-hub==0.30.2 idna==3.10 -importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work -ipykernel @ file:///D:/bld/ipykernel_1719845595208/work -ipython @ file:///D:/bld/bld/rattler-build_ipython_1745672185/work -ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work ipywidgets==8.1.7 -jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work Jinja2==3.1.6 -jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work -jupyter_core @ file:///D:/bld/jupyter_core_1710257313664/work jupyterlab_widgets==3.0.15 kiwisolver==1.4.8 MarkupSafe==3.0.2 matplotlib==3.10.1 -matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work mpmath==1.3.0 multidict==6.4.3 multiprocess==0.70.15 -nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work networkx==3.4.2 numpy==2.2.5 -packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1745345660/work pandas==2.2.3 -parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work peft==0.15.2 -pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work pillow==11.2.1 -platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1742485085/work -prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work propcache==0.3.1 -psutil @ file:///C:/b/abs_b5gv3mn55h/croot/psutil_1736371546320/work -pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work pyarrow==20.0.0 pyarrow-hotfix==0.7 pydantic==2.11.7 pydantic_core==2.33.2 -Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work pyparsing==3.2.3 -python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work pytz==2025.2 pywin32==308 PyYAML==6.0.2 -pyzmq @ file:///D:/bld/pyzmq_1666828541352/work regex==2024.11.6 requests==2.32.3 safetensors==0.5.3 scipy==1.15.2 seaborn==0.13.2 -six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work sniffio==1.3.1 -stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work starlette==0.46.2 sympy==1.14.0 tokenizers==0.21.1 torch==2.7.0+cu128 torchaudio==2.7.0+cu128 torchvision==0.22.0+cu128 -tornado @ file:///D:/bld/tornado_1666788735597/work tqdm==4.67.1 -traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work transformers==4.51.3 typing-inspection==0.4.1 -typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work tzdata==2025.2 urllib3==2.4.0 uvicorn==0.35.0 -wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work widgetsnbextension==4.0.14 xxhash==3.5.0 yarl==1.20.0 -zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work diff --git a/server.py b/server.py index d821978..05857c6 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,8 @@ from fastapi import FastAPI, Request from pydantic import BaseModel from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM import torch +from peft import PeftModel + ## Start the server with this command: uvicorn server:app --reload @@ -16,8 +18,9 @@ if use_base_model: model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to("cuda" if torch.cuda.is_available() else "cpu") else: # Load custom model from Hugging Face Hub - model_name = "Cailean/macbeth" - model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu") + model_name = "Cailean/MacbethPEFT" + base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") + model = PeftModel.from_pretrained(base_model, model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) class Query(BaseModel):