MS Final Oral Exam: Aprameya Madhwaraj

Apr 20, 2026 - 8:00 AM

Automated PyTorch-to-JAX Model Translation with LLM Agents

Manually translating deep learning models from research and prototyping frameworks such as PyTorch to high-performance deployment frameworks such as JAX is time-intensive and error-prone. This project addresses the problem by developing and evaluating an agent-based system that uses Large Language Models (LLMs) to automatically convert PyTorch code into executable JAX. The core objective is to preserve the model architecture, tensor behavior, and training dynamics during translation, so that the translated model's outputs remain consistent with the original's. Several state-of-the-art LLMs were benchmarked, including CodeGemma-7B and various Mistral models, and Devstral-Small-2-24B showed the most robust understanding of the target framework. Systematic error logging revealed recurring API-fidelity and framework-paradigm issues. These findings informed the design of a supervised instruction fine-tuning strategy that uses resource-efficient fine-tuning techniques and targets repeated syntax and API-mapping errors. The results support building a full agent system that accepts raw PyTorch code, automatically repairs common translation errors, and runs self-tests to validate code correctness before final export.

Committee: Simanta Mitra (major professor)