add /detect route for language detection
This commit is contained in:
parent
4ee081df66
commit
4309b55aa4
|
@ -1,4 +1,5 @@
|
||||||
use anyhow::Error;
|
use anyhow::Error;
|
||||||
|
use std::collections::HashMap;
|
||||||
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
|
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
|
||||||
use silicon::utils::{Background, ShadowAdder};
|
use silicon::utils::{Background, ShadowAdder};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
@ -11,6 +12,15 @@ use crate::rgba::{ImageRgba, Rgba};
|
||||||
type FontList = Vec<(String, f32)>;
|
type FontList = Vec<(String, f32)>;
|
||||||
type Lines = Vec<u32>;
|
type Lines = Vec<u32>;
|
||||||
|
|
||||||
|
/// Evaluates a `Result` expression and yields the value inside `Ok`.
///
/// On `Err` the error is discarded and the *enclosing function* returns the
/// second argument instead, so this can only be used where a `return` of that
/// type is valid.
macro_rules! unwrap_or_return {
    ( $result:expr, $fallback:expr ) => {
        match $result {
            // The error value itself is intentionally ignored.
            Err(_) => return $fallback,
            Ok(value) => value,
        }
    };
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
/// Background image URL
|
/// Background image URL
|
||||||
|
@ -132,15 +142,25 @@ impl Config {
|
||||||
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
|
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
|
||||||
// Try using tensorflow to detect the language
|
// Try using tensorflow to detect the language
|
||||||
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
|
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
|
||||||
self.predict_language_with_tensorflow(ps, input_data)
|
let predictions = self.predict_language_with_tensorflow(ps, input_data).unwrap();
|
||||||
.unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
|
|
||||||
|
let mut max_score = -std::f32::INFINITY;
|
||||||
|
let mut max_language = "log";
|
||||||
|
for (language, score) in &predictions { // Borrow predictions here
|
||||||
|
if *score > max_score {
|
||||||
|
max_score = *score;
|
||||||
|
max_language = language;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ps.find_syntax_by_token(max_language).unwrap_or_else(|| ps.find_syntax_by_token("log").unwrap())
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
Ok(language)
|
Ok(language)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> {
|
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<HashMap<String, f32>, Error> {
|
||||||
if self.tf_model_graph.is_none() || self.tf_model.is_none() {
|
if self.tf_model_graph.is_none() || self.tf_model.is_none() {
|
||||||
return Err(Error::msg("TensorFlow model not loaded"));
|
return Err(Error::msg("TensorFlow model not loaded"));
|
||||||
}
|
}
|
||||||
|
@ -165,14 +185,14 @@ impl Config {
|
||||||
|
|
||||||
let classes: Tensor<String> = args.fetch(output_token_classes)?;
|
let classes: Tensor<String> = args.fetch(output_token_classes)?;
|
||||||
|
|
||||||
// Find the index of the highest score
|
let mut result: HashMap<String, f32> = HashMap::new();
|
||||||
let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0;
|
for (i, score) in scores.iter().enumerate() {
|
||||||
|
let class = classes[i].clone();
|
||||||
|
let log_score = score.log2();
|
||||||
|
result.insert(class, log_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
let language = &classes[max_index];
|
|
||||||
let language = ps.find_syntax_by_token(language).unwrap();
|
|
||||||
|
|
||||||
Ok(language)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
62
src/main.rs
62
src/main.rs
|
@ -7,6 +7,7 @@ use anyhow::Error;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use silicon as si;
|
use silicon as si;
|
||||||
use silicon::utils::ToRgba;
|
use silicon::utils::ToRgba;
|
||||||
|
use tensorflow::Tensor;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use std::num::ParseIntError;
|
use std::num::ParseIntError;
|
||||||
|
@ -20,6 +21,7 @@ lazy_static! {
|
||||||
static ref HIGHLIGHTING_ASSETS: si::assets::HighlightingAssets =
|
static ref HIGHLIGHTING_ASSETS: si::assets::HighlightingAssets =
|
||||||
silicon::assets::HighlightingAssets::new();
|
silicon::assets::HighlightingAssets::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! unwrap_or_return {
|
macro_rules! unwrap_or_return {
|
||||||
( $e:expr, $r:expr ) => {
|
( $e:expr, $r:expr ) => {
|
||||||
match $e {
|
match $e {
|
||||||
|
@ -87,6 +89,12 @@ async fn help() -> impl Responder {
|
||||||
"GET /themes": "Return a list of available syntax themes.",
|
"GET /themes": "Return a list of available syntax themes.",
|
||||||
"GET /languages": "Retuns a list of languages which can be parsed.",
|
"GET /languages": "Retuns a list of languages which can be parsed.",
|
||||||
"GET /fonts": "Returns a list of available fonts.",
|
"GET /fonts": "Returns a list of available fonts.",
|
||||||
|
"GET /detect": {
|
||||||
|
"description": "Detect the language of the given code.",
|
||||||
|
"parameters": {
|
||||||
|
"code": "The code to detect the language of. Required."
|
||||||
|
}
|
||||||
|
},
|
||||||
"GET /generate": {
|
"GET /generate": {
|
||||||
"description": "Generate an image from the given code.",
|
"description": "Generate an image from the given code.",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
|
@ -150,6 +158,59 @@ async fn fonts() -> impl Responder {
|
||||||
HttpResponse::Ok().json(fonts)
|
HttpResponse::Ok().json(fonts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[get("/detect")]
|
||||||
|
async fn detect(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||||
|
let args = CliArgs::parse();
|
||||||
|
let ha = &*HIGHLIGHTING_ASSETS;
|
||||||
|
|
||||||
|
let (ps, _ts) = (&ha.syntax_set, &ha.theme_set);
|
||||||
|
|
||||||
|
let mut conf = config::Config::default();
|
||||||
|
conf.code = info.code.clone();
|
||||||
|
if conf.code.is_empty() {
|
||||||
|
return HttpResponse::BadRequest()
|
||||||
|
.append_header(("Content-Type", "application/json"))
|
||||||
|
.body(r#"{"error": "code parameter is required"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.tensorflow_model_dir.is_some() {
|
||||||
|
conf.load_tensorflow_model(args.tensorflow_model_dir.unwrap().as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
let input_data = Tensor::new(&[1]).with_values(&[conf.code.clone()]).unwrap();
|
||||||
|
let predictions = unwrap_or_return!(
|
||||||
|
conf.predict_language_with_tensorflow(ps, input_data),
|
||||||
|
HttpResponse::BadRequest()
|
||||||
|
.append_header(("Content-Type", "application/json"))
|
||||||
|
.body(r#"{"error": "Failed to detect language."}"#)
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut sorted_predictions: Vec<_> = predictions.iter().collect();
|
||||||
|
sorted_predictions.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||||
|
|
||||||
|
let min_score = predictions.iter().map(|(_, score)| *score).fold(f32::INFINITY, f32::min);
|
||||||
|
let max_score = predictions.iter().map(|(_, score)| *score).fold(f32::NEG_INFINITY, f32::max);
|
||||||
|
|
||||||
|
// Normalize scores and pick top 5
|
||||||
|
let mut normalized_predictions: Vec<_> = predictions.iter().map(|(lang, score)| {
|
||||||
|
let normalized_score = (score - min_score) / (max_score - min_score) * 100.0;
|
||||||
|
(lang, normalized_score)
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
normalized_predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||||
|
|
||||||
|
let response = normalized_predictions
|
||||||
|
.iter()
|
||||||
|
// .take(5)
|
||||||
|
.map(|(language, score)| format!("{{\"language\": \"{}\", \"score\": {}}}", language, score))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(",");
|
||||||
|
|
||||||
|
HttpResponse::Ok()
|
||||||
|
.append_header(("Content-Type", "application/json"))
|
||||||
|
.body(format!("[{}]", response))
|
||||||
|
}
|
||||||
|
|
||||||
#[get("/generate")]
|
#[get("/generate")]
|
||||||
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||||
let args = CliArgs::parse();
|
let args = CliArgs::parse();
|
||||||
|
@ -290,6 +351,7 @@ async fn main() -> std::io::Result<()> {
|
||||||
.service(themes)
|
.service(themes)
|
||||||
.service(languages)
|
.service(languages)
|
||||||
.service(fonts)
|
.service(fonts)
|
||||||
|
.service(detect)
|
||||||
.service(generate)
|
.service(generate)
|
||||||
})
|
})
|
||||||
.bind((host.clone(), port.parse::<u16>().unwrap()))?
|
.bind((host.clone(), port.parse::<u16>().unwrap()))?
|
||||||
|
|
Loading…
Reference in New Issue