truncation true

Maurizio Dipierro 2025-03-19 17:06:26 +01:00
parent 6023f099f8
commit 06a5c59a47
1 changed file with 8 additions and 5 deletions

app.py (13 changed lines)

@@ -41,10 +41,11 @@ def normalize_values(text, target_max=500):
     return normalized_text
 
 processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
-model = AutoModelForVision2Seq.from_pretrained("ds4sd/SmolDocling-256M-preview",
+model = AutoModelForVision2Seq.from_pretrained(
+    "ds4sd/SmolDocling-256M-preview",
     torch_dtype=torch.bfloat16,
     # _attn_implementation="flash_attention_2"
 ).to("cuda")
 
 def model_inference(input_dict, history):
     text = input_dict["text"]
@@ -77,7 +78,9 @@ def model_inference(input_dict, history):
         }
     ]
     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-    inputs = processor(text=prompt, images=[images], return_tensors="pt").to('cuda')
+    # Added truncation=True to explicitly activate truncation.
+    inputs = processor(text=prompt, images=[images], return_tensors="pt", truncation=True).to('cuda')
     generation_args = {
         "input_ids": inputs.input_ids,